In [11]:
!pip install git+https://github.com/rwightman/pytorch-image-models

In [22]:
import gc
import glob
import os
from pprint import pprint

import IPython
import numpy as np
import pandas as pd
import PIL
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from IPython.display import FileLink
from torch.nn import Flatten
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from torchvision import models
from tqdm.auto import tqdm

Config

In [23]:
class Config:
    """Inference settings: input size, batch size, device and input paths."""
    # Target height/width every image is resized to before entering the model.
    IMG_SIZE_H = 224
    IMG_SIZE_W = 224
    BATCH_SIZE = 128
    # Use the GPU when one is present, otherwise fall back to the CPU so the
    # notebook still runs end to end on a machine without CUDA.
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Folder with the images to classify.
    test_images = '../input/trash-containers/test_dataset_test'
    # Checkpoint file; it is read as a dict with a 'model_state_dict' entry.
    weights_path = '../input/trash-containers-cls-weights/swin224_ml_0.91.pt'

class ModelConfig:
    """Backbone selection for the classifier."""

    # Architecture identifier handed to timm.create_model.
    model_name = "swin_large_patch4_window7_224"

    # Feature width expected by the linear head. Equal to the default
    # `base_model_output_size` of SWIN; not referenced elsewhere in the
    # visible notebook.
    linear_layer_input_size = 1536

In [28]:
# Test-time preprocessing: resize to the configured input size, convert to a
# tensor, then normalise each channel (the constants are the commonly used
# ImageNet channel statistics).
transform_test = transforms.Compose(
    [
        transforms.Resize((Config.IMG_SIZE_H, Config.IMG_SIZE_W)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [29]:
class CustomInferenceDataset(torch.utils.data.Dataset):
    """Unlabelled image dataset over the files of one folder.

    Args:
        images_folder: Directory that is searched (non-recursively) with the
            glob pattern ``*.jp*``.
        test_transform: Optional callable applied to every loaded image.

    Each item is a tuple ``(image, file_name)`` where ``file_name`` is the
    name of the file without its directory.
    """

    def __init__(self, images_folder, test_transform = None):
        self.images_folder = images_folder
        self.transform = test_transform
        # Sorted so the sample order is the same on every run.
        self.list_files = sorted(glob.glob(f'{images_folder}/*.jp*'))

    def __len__(self):
        """Return the number of image files that were found."""
        return len(self.list_files)

    def __getitem__(self, index):
        """Load image number ``index`` and return it with its base file name."""
        file_name = self.list_files[index]
        # Force three channels: grayscale/CMYK/palette JPEGs would otherwise
        # reach the 3-channel Normalize step with the wrong shape. convert()
        # returns a fully loaded copy, so the file can be closed right away
        # instead of leaking one open handle per sample.
        with PIL.Image.open(file_name) as source_image:
            image = source_image.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        # basename() also handles OS-specific separators returned by glob.
        return image, os.path.basename(file_name)

In [30]:
# Dataset over the test folder, preprocessed with the test-time transform.
test_ds = CustomInferenceDataset(
    Config.test_images,
    test_transform=transform_test,
)

# shuffle=False keeps predictions in the dataset's (sorted) file order.
test_loader = DataLoader(
    test_ds,
    batch_size=Config.BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

Model

In [32]:
class SWIN(nn.Module):
    """Frozen timm backbone followed by a small classification head.

    The backbone is created from ``ModelConfig.model_name`` with
    ``num_classes=0`` (timm then returns pooled features instead of logits)
    and all of its parameters are frozen on construction.

    Args:
        fc_layers: If True, ``forward`` uses the two-layer head
            (fc1 -> ReLU -> BatchNorm -> Dropout -> fc2); otherwise the single
            linear layer ``fc``.
        fc_layer_sz: Width of the hidden layer of the two-layer head.
        base_model_output_size: Width of the feature vector produced by the
            backbone.
        num_classes: Number of output logits (default 3, the previously
            hard-coded value).
        pretrained: Forwarded to ``timm.create_model``. Pass False when a
            checkpoint is loaded right after construction, to avoid fetching
            pretrained weights that would be overwritten anyway.
    """

    def __init__(self, fc_layers=True, fc_layer_sz=256, base_model_output_size=1536,
                 num_classes=3, pretrained=True):
        super().__init__()

        self.fc_layers = fc_layers
        self.fc_layer_sz = fc_layer_sz
        self.base_model_output_size = base_model_output_size

        self.swin = timm.create_model(
            ModelConfig.model_name, in_chans=3, pretrained=pretrained, num_classes=0
        )
        print(ModelConfig.model_name)
        for param in self.swin.parameters():
            param.requires_grad = False

        # All head layers are created regardless of `fc_layers`, so the
        # module's state_dict keys do not depend on that flag.
        self.fc = nn.Linear(self.base_model_output_size, num_classes)
        self.fc1 = nn.Linear(self.base_model_output_size, self.fc_layer_sz)
        self.fc2 = nn.Linear(self.fc_layer_sz, num_classes)

        self.batchnorm = nn.BatchNorm1d(self.fc_layer_sz)
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()

    def fc_layer(self, x, after_first_layer=False):
        """Apply the two-layer head to backbone features ``x``.

        With ``after_first_layer=True`` only ``fc1`` is applied and its raw
        output (no activation) is returned; otherwise the full head is run
        and the logits are returned.
        """
        if after_first_layer:
            return self.fc1(x)

        hidden = self.relu(self.fc1(x))
        hidden = self.batchnorm(hidden)
        hidden = self.dropout(hidden)
        return self.fc2(hidden)

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        features = self.swin(x)
        if self.fc_layers:
            return self.fc_layer(features)
        return self.fc(features)

    def get_features(self, x, after_fc=False):
        """Return intermediate features for a batch of images ``x``.

        * ``after_fc=False``: the backbone output.
        * ``after_fc=True`` with the two-layer head: the output of ``fc1``.
        * ``after_fc=True`` with the single-layer head: same as ``forward``.
        """
        if not after_fc:
            return self.swin(x)
        if self.fc_layers:
            return self.fc_layer(self.swin(x), True)
        return self.forward(x)

    def set_parameter_requires_grad(self, freeze: bool):
        """Freeze (``freeze=True``) or unfreeze the backbone parameters.

        When unfreezing, direct children of the backbone whose name contains
        'norm' are kept frozen. Not called anywhere in the visible notebook.
        """
        for param in self.swin.parameters():
            param.requires_grad = not freeze
        if not freeze:
            for name, child in self.swin.named_children():
                if 'norm' in name:
                    for param in child.parameters():
                        param.requires_grad = False

In [33]:
# Collect unreachable Python objects first, so that tensors they still own are
# released back to PyTorch's caching allocator; only then ask the allocator to
# hand its unused cached blocks back to the GPU.
gc.collect()
torch.cuda.empty_cache()

In [34]:
net = SWIN().to(Config.DEVICE)

model_path = Config.weights_path
# map_location='cpu' lets the checkpoint load even when the device it was
# saved from is not available here; load_state_dict then copies the values
# into the parameters on Config.DEVICE.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
net.load_state_dict(
    torch.load(model_path, map_location='cpu')['model_state_dict']
)

In [53]:
preds = []
file_names_all = []

# Inference mode for BatchNorm/Dropout; no_grad skips autograd bookkeeping.
net.eval()
with torch.no_grad():
    for images, file_names in tqdm(test_loader):
        file_names_all.extend(file_names)

        logits = net(images.to(Config.DEVICE))
        # argmax is taken on the raw logits: sigmoid is monotonic, so it never
        # changes the winning class, but in float32 it saturates to exactly
        # 1.0 for large logits and can turn distinct scores into ties.
        batch_preds = torch.argmax(logits, dim=1)
        preds.extend(batch_preds.cpu().tolist())

In [54]:
# NOTE(review): the file name says "swin_base" while ModelConfig selects a
# swin_large model -- confirm the name is intentional.
submit_csv_file_name = 'swin_base_multilabel_val_spl.csv'

submit = pd.DataFrame({
    # splitext removes only the final extension, so an image name that itself
    # contains dots is not truncated at its first dot.
    'ID_img': [os.path.splitext(filen)[0] for filen in file_names_all],
    'class': preds,
})
submit.to_csv(submit_csv_file_name, index=False)
submit.head(10)

In [37]:
# Audible notification that the cells above have finished; the sound is
# fetched from the given URL and starts playing automatically.
display(
    IPython.display.Audio(
        url="https://upload.wikimedia.org/wikipedia/commons/0/05/Beep-09.ogg",
        autoplay=True,
    )
)

In [39]:
# Render a clickable link to the submission CSV written above.
FileLink('./' + submit_csv_file_name)