In [1]:
!pip install git+https://github.com/rwightman/pytorch-image-models
!pip install -q -U albumentations

In [2]:
import numpy as np
import pandas as pd
import os
import glob

import IPython
from IPython.display import FileLink

import matplotlib.pyplot as plt
import seaborn as sns

from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Flatten
from torch.optim.lr_scheduler import StepLR

from torch.utils.data import DataLoader, Dataset, Subset

from torchvision import datasets, transforms
from torchvision import models

import albumentations as A
from albumentations.pytorch import ToTensorV2


import PIL

import timm
from pprint import pprint

from sklearn.preprocessing import label_binarize

from sklearn.metrics import roc_auc_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from sklearn.metrics.pairwise import cosine_similarity

from sklearn.utils.class_weight import compute_class_weight
from sklearn.manifold import TSNE
import umap
from sklearn.preprocessing import StandardScaler

from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC

In [3]:
# Raw training labels: one row per image file (columns used below: ID_img, class).
TRAIN_LABELS_CSV = '../input/trash-containers/train_dataset_train/train.csv'
train_csv = pd.read_csv(TRAIN_LABELS_CSV)

# EDA

In [4]:
# Class balance of the raw training labels.
sns.countplot(data=train_csv, x="class")

In [5]:
# Four random examples per class, laid out as three 2x2 panels side by side
# (class 0 = no bins, class 1 = bins present, class 2 = other).
fig, axs = plt.subplots(2, 6, figsize=(22,8))
fig.suptitle(f'Нет мусорок {" "*105} Есть мусорки {" "*105} другое', fontsize=14)

train_path = '../input/trash-containers/train_dataset_train/train'

for class_id in range(3):
    sample_names = train_csv[train_csv['class'] == class_id].sample(4, random_state=42)['ID_img']
    for slot, name in enumerate(sample_names):
        # Class `class_id` occupies grid columns 2*class_id and 2*class_id + 1.
        ax = axs[slot // 2, slot % 2 + 2 * class_id]
        ax.imshow(plt.imread(f"{train_path}/{name}"))
        ax.axis('off')

fig.tight_layout()
fig.subplots_adjust(top=0.88)

### Data transform config

In [6]:
class ConfigDataTransform:
    """Switches for handling the hand-picked mislabeled ("dirty") training images."""
    remove_dirty_labels = False   # drop the listed images from the training table
    replace_dirty_labels = True   # relabel the listed images as class 2
    # Raw training table, read once when this class is defined.
    train_csv = pd.read_csv('../input/trash-containers/train_dataset_train/train.csv')

Remove or relabel mislabeled ("dirty") examples in the training set

In [7]:
# Drop or relabel (to class 2) hand-picked mislabeled training images, then write
# the resulting table to ./train.csv.
if ConfigDataTransform.remove_dirty_labels or ConfigDataTransform.replace_dirty_labels:
    # Listed images from class 0 (per the list name).
    r_from_0 = ['220304014444_93d324658e78782c37f3bce0c65e03b2.jpg', '220307012903_c2e25c71975244587fb02517aec28abb.jpg',
                '220307023015_d7805fde0aeda728f0136ec7482cc43f.jpg', '220307023034_9d7ebe271649a6ebbc3494c026126869.jpg',
                '220307024716_902e232dfe9c5ac55643703abcbf1d11.jpg', '220307123436_a755745087e8d1f6efb7dde875b3f1be.jpg']
    # Listed images from class 1 (per the list name).
    r_from_1 = ['220301070625_1be30d27fcfc119ed8dee6037a9ab385.jpg', '220301124526_ba6ffcfd54b8a2e2d7120082abf89be8.jpg',
                '220302013218_0dc7f7864ff931e2f0e2673f89cfb47c.jpg', '220303054740_9daab175576b21466a151bc4b9a0536f.jpg',
                '220304060156_51b2fe94598d626677f65085635568fa.jpg', '220304063326_d00d7a35c1490bf3ce4daeb58f8eaf7e.jpg']

    # Work on a copy: mutating the table cached on ConfigDataTransform would leak
    # the edit into any later use of it and make this cell non-idempotent.
    train_csv = ConfigDataTransform.train_csv.copy()
    r_from_0_idxs = train_csv[train_csv['ID_img'].isin(r_from_0)].index
    r_from_1_idxs = train_csv[train_csv['ID_img'].isin(r_from_1)].index
    dirty_idxs = r_from_0_idxs.union(r_from_1_idxs)

    if ConfigDataTransform.remove_dirty_labels:
        # Removal takes precedence: relabeling afterwards would target row labels
        # that are no longer in the frame.
        train_csv = train_csv.drop(dirty_idxs)
    elif ConfigDataTransform.replace_dirty_labels:
        train_csv.loc[dirty_idxs, 'class'] = 2

    train_csv.to_csv('./train.csv', index=False)

### Config

In [69]:
class Config:
    """Run-wide settings: input size, schedule, paths and which stages to run."""
    # Model input resolution (height, width).
    IMG_SIZE_H = 224
    IMG_SIZE_W = 224
    BATCH_SIZE = 128
    EPOCHS = 25          # epochs for the frozen-backbone training stage
    FT_EPOCHS = 31       # epochs for the fine-tuning stage
    VAL_IN_EPOCHS = 1    # validate every N epochs
    # Fall back to CPU so the notebook still runs on a machine without CUDA.
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_images = '../input/trash-containers/train_dataset_train/train'
    test_images = '../input/trash-containers/test_dataset_test'
    val_csv_path = './val.csv'
    load_val_from_csv = False
    # Reassigned after this cell depending on the ConfigDataTransform flags.
    train_csv_path = './train.csv'
    # Stage switches.
    train = False
    fine_tune = False
    val = True # val dataset is mandatory for training
    inference = True
    submit = True
    # Checkpoint loading.
    load_weights = True
    weights_path_var = 'old' # 'old' loads Config.weights_path, 'new' loads Config.new_current_weights_path
    weights_path = '../input/trash-containers-cls-weights/swin224_ml_0.91.pt'
    new_current_weights_path = None  # set by save_model() to the latest checkpoint path
    augmentator='torchvision' # 'torchvision' or 'albs'
    remove_old_weights = True  # delete ./*.pt before training
    train_val_split = True # need to set train_val_split or load_val_from_csv
    val_size = 0.25  # fraction of the training table held out for validation

class ModelConfig:
    """Backbone choice for the SWIN model."""
    # timm architecture name; SWIN.__init__ builds its backbone from this.
    model_name = 'swin_large_patch4_window7_224'
    # Pooled feature size of `model_name`.
    # NOTE(review): not read by SWIN, which has its own default of 1536 — keep in sync.
    linear_layer_input_size = 1536
    # Unused, previously considered settings: base_lr 3e-3, ft_lr 3e-5, opt 'adamw'.

# Train on the cleaned table written by the cleaning cell when any cleaning
# flag is on; otherwise use the raw competition CSV.
_labels_cleaned = ConfigDataTransform.remove_dirty_labels or ConfigDataTransform.replace_dirty_labels
Config.train_csv_path = (
    './train.csv'
    if _labels_cleaned
    else '../input/trash-containers/train_dataset_train/train.csv'
)

Class weights

In [37]:
# sklearn "balanced" weights for classes 0, 1, 2, computed from the label table
# actually used for training (raw or cleaned, per Config.train_csv_path).
training_labels = pd.read_csv(Config.train_csv_path)['class']
class_weights = compute_class_weight('balanced', classes=[0, 1, 2], y=training_labels)

# Augmentations

In [18]:
if Config.augmentator == 'torchvision':
    # Steps shared by both pipelines: fixed-size resize and ImageNet mean/std normalization.
    _resize = transforms.Resize((Config.IMG_SIZE_H, Config.IMG_SIZE_W))
    _normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])

    # Training: random flips in both axes plus rotation of up to +/-90 degrees.
    transform_train = transforms.Compose([
        _resize,
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(90),
        transforms.ToTensor(),
        _normalize,
    ])

    # Evaluation: deterministic, no augmentation.
    transform_test = transforms.Compose([_resize, transforms.ToTensor(), _normalize])

In [19]:
if Config.augmentator == 'albs':
    # Both pipelines scale the longest side to IMG_SIZE_H and pad to IMG_SIZE_H x IMG_SIZE_W.
    transform_train = A.Compose([
            A.LongestMaxSize(max_size=Config.IMG_SIZE_H, interpolation=1),
            A.PadIfNeeded(Config.IMG_SIZE_H, Config.IMG_SIZE_W, p=1.0),
            # geometric augmentations
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
            A.Rotate(limit=[0, 90], p=0.5),
            # colour augmentations
            A.RandomBrightnessContrast(p=0.3),
            A.CLAHE(p=0.3),
            A.FancyPCA(p=0.3),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ])

    # Deterministic evaluation pipeline. Bound to transform_test only: chaining the
    # assignment to transform_train would discard the augmented pipeline above.
    transform_test = A.Compose([
            A.LongestMaxSize(max_size=Config.IMG_SIZE_H, interpolation=1),
            A.PadIfNeeded(Config.IMG_SIZE_H, Config.IMG_SIZE_W, p=1.0),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ])

# Dataset

In [20]:
class CustomDataset(torch.utils.data.Dataset):
    """Labelled images listed in a CSV of (file name, class) rows.

    The CSV's first column (the image file name) becomes the index; the next
    column is read as the integer label.

    Args:
        csv_path: path to the label CSV.
        images_folder: directory containing the files named in the CSV.
        transform: optional transform. Called as ``transform(img)`` when
            ``Config.augmentator == "torchvision"`` and as
            ``transform(image=array)['image']`` when it is ``"albs"``.
    """

    def __init__(self, csv_path, images_folder, transform = None):
        self.df = pd.read_csv(csv_path, index_col=0)
        self.images_folder = images_folder
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(image, label)`` for row ``index``."""
        filename = self.df.iloc[index].name
        label = int(self.df.iloc[index, 0])
        # Decode eagerly and close the file (Image.open is lazy and keeps the
        # handle open); force RGB so grayscale/palette/RGBA files still produce
        # the 3 channels the normalization expects.
        with PIL.Image.open(os.path.join(self.images_folder, filename)) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            if Config.augmentator == "torchvision":
                return self.transform(image), label
            if Config.augmentator == "albs":
                return self.transform(image=np.array(image))['image'], label
        return image, label

In [21]:
class CustomInferenceDataset(torch.utils.data.Dataset):
    """Unlabelled images: every ``*.jp*`` file in ``images_folder``, sorted by path.

    Items are ``(image, file_name)`` where ``file_name`` is the file's base name.

    Args:
        images_folder: directory to scan.
        test_transform: optional transform, applied according to
            ``Config.augmentator`` exactly as in ``CustomDataset``.
    """

    def __init__(self, images_folder, test_transform = None):
        self.images_folder = images_folder
        self.transform = test_transform
        self.list_files = sorted(glob.glob(f'{images_folder}/*.jp*'))

    def __len__(self):
        return len(self.list_files)

    def __getitem__(self, index):
        file_name = self.list_files[index]
        base_name = os.path.basename(file_name)
        # Decode eagerly, close the handle, and force 3-channel RGB.
        with PIL.Image.open(file_name) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            if Config.augmentator == "torchvision":
                return self.transform(image), base_name
            if Config.augmentator == "albs":
                return self.transform(image=np.array(image))['image'], base_name
        return image, base_name

In [38]:
# Full labelled training table with training transforms; a validation subset
# may be split off in the next cell.
train_ds = CustomDataset(
    csv_path=Config.train_csv_path,
    images_folder=Config.train_images,
    transform=transform_train,
)

Split into train/val

In [39]:
def subset_ind(dataset, ratio: float, seed=None):
    """Pick ``int(ratio * len(dataset))`` distinct random indices into ``dataset``.

    Args:
        dataset: any sized collection.
        ratio: fraction of items to select.
        seed: optional seed for a reproducible selection. ``None`` (default)
            draws from NumPy's global RNG, as before.

    Returns:
        1-D integer array of unique indices.
    """
    size = int(ratio * len(dataset))
    rng = np.random if seed is None else np.random.default_rng(seed)
    return rng.choice(len(dataset), size=size, replace=False)

if Config.train_val_split:
    val_size = Config.val_size
    # Fixed seed so the same images land in validation on every run.
    np.random.seed(42)
    val_inds = subset_ind(train_ds, val_size)
    # Set lookup keeps the complement O(n); `i not in ndarray` scans the array per index.
    val_set = set(val_inds.tolist())
    train_inds = [i for i in range(len(train_ds)) if i not in val_set]

    # Validation must not see random flips/rotations: index the same CSV through
    # a second dataset that applies the deterministic test transform.
    val_base_ds = CustomDataset(Config.train_csv_path, Config.train_images, transform=transform_test)

    train_dataset = Subset(train_ds, train_inds)
    val_dataset = Subset(val_base_ds, val_inds)
    train_ds = train_dataset
    val_ds = val_dataset

    print(f'training size: {len(train_dataset)}\nvalidation size: {len(val_dataset)}')

In [32]:
# Optional validation set from a separate CSV (alternative to the random split).
# NOTE(review): images are read from Config.test_images; if val.csv lists labelled
# training images this should probably be Config.train_images — confirm.
if Config.load_val_from_csv:
    val_ds = CustomDataset(
        csv_path=Config.val_csv_path,
        images_folder=Config.test_images,
        transform=transform_test,
    )

# Unlabelled test images for the submission.
if Config.inference:
    test_ds = CustomInferenceDataset(images_folder=Config.test_images, test_transform=transform_test)

Visualize the first 2 images from the dataset

In [41]:
def visualize_pic(ds, sample_count=2, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Show the first ``sample_count`` images of ``ds``, undoing normalization.

    Args:
        ds: dataset whose items start with a channels-first image tensor.
        sample_count: number of leading items to show (capped at ``len(ds)``).
        mean, std: per-channel normalization to invert before plotting. The
            defaults match the Normalize step used by this notebook's
            transforms; pass ``None`` for either to plot the tensor as-is.
    """
    for i in range(min(sample_count, len(ds))):
        image = ds[i][0]
        if mean is not None and std is not None:
            std_t = torch.as_tensor(std, dtype=image.dtype).view(-1, 1, 1)
            mean_t = torch.as_tensor(mean, dtype=image.dtype).view(-1, 1, 1)
            image = image * std_t + mean_t
        # imshow wants HWC floats in [0, 1]; clamp instead of letting it clip with a warning.
        image = torch.permute(image, (1, 2, 0)).clamp(0, 1)
        plt.imshow(image.numpy())
        plt.show()

In [42]:
# Eyeball two augmented training samples.
visualize_pic(train_ds, sample_count=2)

Dataloaders

In [46]:
# Loader settings shared by all splits; only the training loader shuffles.
loader_kwargs = dict(batch_size=Config.BATCH_SIZE, pin_memory=True, num_workers=2)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
if Config.inference:
    test_loader = DataLoader(test_ds, shuffle=False, **loader_kwargs)

# Model

In [49]:
# Every timm architecture that ships pretrained weights (useful when choosing a backbone).
available_pretrained = timm.list_models(pretrained=True)
model_names = available_pretrained

In [50]:
# Candidate backbones; index 1 is the one inspected below.
models_names_s = [
    'swin_large_patch4_window12_384_in22k',
    'swin_large_patch4_window7_224',
    'swin_base_patch4_window12_384_in22k',
    'vit_small_patch16_384',
    'swinv2_base_window8_256',
]

In [52]:
# This model is only used to inspect the architecture (its default_cfg and the
# output shapes in the next cell), so don't load a second copy of the large
# pretrained weights into memory; SWIN() loads its own pretrained backbone.
model = timm.create_model(models_names_s[1], in_chans=3, pretrained=False, num_classes=0, global_pool='avg')
print(model.default_cfg)

In [53]:
# Sanity-check backbone output shapes on one random image. Show both the
# unpooled features and the pooled output; no autograd graph is needed.
dummy_image = torch.randn(1, 3, Config.IMG_SIZE_H, Config.IMG_SIZE_W)
with torch.no_grad():
    features_shape = model.forward_features(dummy_image).shape
    pooled_shape = model.forward(dummy_image).shape
features_shape, pooled_shape

In [54]:
class SWIN(nn.Module):
    """timm backbone (``ModelConfig.model_name``) with a 3-class head.

    The backbone starts frozen. Two heads are always built (so checkpoints hold
    both): a single linear layer ``fc`` and the two-layer head
    ``fc1 -> ReLU -> BatchNorm -> Dropout(0.5) -> fc2``; ``fc_layers`` selects
    which one ``forward`` uses.

    Args:
        fc_layers: use the two-layer head when True, ``fc`` otherwise.
        fc_layer_sz: hidden width of the two-layer head.
        base_model_output_size: size of the backbone's output feature vector.
    """
    def __init__(self, fc_layers=True, fc_layer_sz=256, base_model_output_size=1536):
        super().__init__()

        self.fc_layers = fc_layers
        self.fc_layer_sz = fc_layer_sz
        self.base_model_output_size = base_model_output_size

        # num_classes=0 removes timm's classifier so the backbone returns features.
        self.swin = timm.create_model(ModelConfig.model_name, in_chans = 3, pretrained = True, num_classes=0)
        print(ModelConfig.model_name)
        self.swin.requires_grad_(False)

        # Registration order is kept as-is: it fixes the parameters() order that
        # optimizer state dicts saved by save_model() rely on.
        self.fc = nn.Linear(self.base_model_output_size, 3)
        self.fc1 = nn.Linear(self.base_model_output_size, self.fc_layer_sz)
        self.fc2 = nn.Linear(self.fc_layer_sz, 3)
        self.batchnorm = nn.BatchNorm1d(self.fc_layer_sz)
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()

    def fc_layer(self, x, after_first_layer=False):
        """Apply the two-layer head; with ``after_first_layer`` return fc1's raw output."""
        hidden = self.fc1(x)
        if after_first_layer:
            return hidden
        hidden = self.dropout(self.batchnorm(self.relu(hidden)))
        return self.fc2(hidden)

    def forward(self, x):
        """Return 3-class logits for the image batch ``x``."""
        features = self.swin(x)
        head = self.fc_layer if self.fc_layers else self.fc
        return head(features)

    def get_features(self, x, after_fc=False):
        """Embeddings for ``x``.

        ``after_fc=False``: backbone features. ``after_fc=True``: fc1 output for the
        two-layer head, or the logits when the single ``fc`` head is used.
        """
        if not after_fc:
            return self.swin(x)
        if self.fc_layers:
            return self.fc_layer(self.swin(x), True)
        return self.forward(x)

    def set_parameter_requires_grad(self, freeze: bool):
        """Freeze (``freeze=True``) or unfreeze the backbone; used by the fine-tuning cell.

        When unfreezing, direct children of the backbone whose name contains
        'norm' are kept frozen.
        """
        self.swin.requires_grad_(not freeze)
        if freeze:
            return
        for child_name, child in self.swin.named_children():
            if 'norm' in child_name:
                child.requires_grad_(False)

# Train

In [55]:
def save_model(epoch, net, optimizer, loss, metrics):
    """Write a checkpoint and record its path in ``Config.new_current_weights_path``.

    The file is ``./best_model_{epoch}_{metrics rounded to 3 decimals}.pt``. It holds
    ``epoch``, ``model_state_dict``, ``optimizer_state_dict`` and ``loss``.

    Args:
        epoch: epoch index stored in the checkpoint and used in the file name.
        net: model whose state dict is saved.
        optimizer: optimizer whose state dict is saved.
        loss: loss to store; a tensor is converted to a Python float.
        metrics: score used in the file name.

    Returns:
        The checkpoint path (callers that ignore the return value are unaffected).
    """
    path = f'./best_model_{epoch}_{str(np.round(metrics, 3))}.pt'
    # A loss tensor from the training loop lives on the training device; storing a
    # plain float keeps the checkpoint loadable without that device.
    if torch.is_tensor(loss):
        loss = loss.item()
    torch.save({
        'epoch': epoch, 'model_state_dict': net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss, }, path)
    Config.new_current_weights_path = path
    return path

In [56]:
def train_batch_step(batch, device, optimizer, loss_fn, epoch_loss, train_size, train_pred, net=None):
    """Run one optimization step on ``batch`` and fold it into the epoch statistics.

    Args:
        batch: ``(images, labels)`` pair with integer class labels in {0, 1, 2}.
        device: device the batch is moved to.
        optimizer: optimizer stepping the model's parameters.
        loss_fn: pair of loss modules; each gets the raw logits and one-hot float
            targets, and their sum is optimized.
        epoch_loss: running sum of ``batch_size * batch_loss``.
        train_size: running number of samples seen.
        train_pred: running number of correct predictions.
        net: model to train. Defaults to the notebook-level ``net`` so existing
            calls that do not pass it keep working.

    Returns:
        ``(epoch_loss, train_size, train_pred, loss)`` updated with this batch;
        ``loss`` is the batch's loss tensor.
    """
    if net is None:
        # Backward-compatible fallback to the global model.
        net = globals()['net']

    X_batch, y_batch = batch
    X_batch = X_batch.to(device)
    y_batch = y_batch.to(device)

    optimizer.zero_grad()
    y_pred = net(X_batch)

    target = F.one_hot(y_batch, num_classes=3).float()
    loss = loss_fn[0](y_pred, target) + loss_fn[1](y_pred, target)

    loss.backward()
    optimizer.step()

    # Sigmoid is kept so decoding matches the validation and inference code.
    y_pred = torch.argmax(torch.sigmoid(y_pred), 1).view(-1)
    y_batch = y_batch.view(-1)

    epoch_loss += y_pred.shape[0] * loss.item()
    train_size += y_pred.size(0)
    train_pred += torch.sum(y_pred == y_batch)

    return epoch_loss, train_size, train_pred, loss
    
def val_batch_step(batch, device, net, loss_fn, epoch_val_loss, y_true_a, y_pred_a):
    """Evaluate ``net`` on one validation batch.

    Appends one-hot rows (3 classes) for the true and predicted labels to
    ``y_true_a`` / ``y_pred_a`` in place. Returns ``epoch_val_loss`` plus
    ``batch_size * batch_loss``, where the loss is the sum of both ``loss_fn``
    entries applied to the logits and one-hot float targets.
    """
    images, labels = batch
    images, labels = images.to(device), labels.to(device)
    logits = net(images)

    targets = F.one_hot(labels, num_classes=3).float()
    batch_loss = loss_fn[0](logits, targets) + loss_fn[1](logits, targets)

    predicted = torch.argmax(torch.sigmoid(logits), 1).view(-1).detach().cpu().numpy()
    actual = labels.view(-1).detach().cpu().numpy()

    one_hot_rows = np.eye(3)
    y_true_a.extend(one_hot_rows[actual])
    y_pred_a.extend(one_hot_rows[predicted])

    return epoch_val_loss + predicted.shape[0] * batch_loss.item()

In [57]:
def train(net, loss_fn, optimizer, scheduler, train_loader,val_loader, metrics_log, n_epoch=10):
    """Train ``net``, validate periodically and checkpoint the best validation ROC AUC.

    Each epoch runs ``train_batch_step`` over ``train_loader`` and steps ``scheduler``
    once. Every ``Config.VAL_IN_EPOCHS`` epochs it evaluates on ``val_loader`` and
    calls ``save_model`` whenever ROC AUC improves.

    Args:
        net: model being trained.
        loss_fn: pair of loss modules whose sum is optimized.
        optimizer: optimizer for ``net``'s parameters.
        scheduler: LR scheduler, stepped once per epoch.
        train_loader, val_loader: DataLoaders yielding ``(images, labels)``.
        metrics_log: dict filled with per-epoch lists: ``train_acc`` and
            ``train_loss`` (every epoch), ``val_roc_auc`` and ``val_loss``
            (validated epochs only). Losses are per-sample means.
        n_epoch: number of epochs.

    Returns:
        ``net`` as it is after the last epoch (not necessarily the best checkpoint).
    """
    best_metrics = 0
    device = Config.DEVICE
    train_acc_log = []
    val_roc_auc = []
    train_loss = []
    val_loss = []

    for epoch in range(n_epoch):
        print(f'Epoch {epoch + 1}')
        net.train()

        train_size = 0
        train_pred = 0.
        epoch_loss = 0.0
        epoch_val_loss = 0.0
        loss = None

        # NOTE(review): train_batch_step is called without a model, so it trains the
        # notebook-level `net`; make sure that is the same object as this `net`.
        for batch in tqdm(train_loader):
            epoch_loss, train_size, train_pred, loss = train_batch_step(batch, device, optimizer, loss_fn, epoch_loss, train_size, train_pred)

        scheduler.step()

        # epoch_loss sums batch_size * batch_loss, so normalise by the number of
        # samples (dividing by the number of batches inflated it ~BATCH_SIZE-fold).
        train_acc = float(train_pred) / train_size if train_size else 0.0
        train_loss_ = epoch_loss / train_size if train_size else 0.0

        train_acc_log.append(train_acc)
        train_loss.append(train_loss_)

        print('train loss: ', train_loss_)
        print('train acc: ', train_acc)

        if (epoch % Config.VAL_IN_EPOCHS) == 0:
            with torch.no_grad():
                net.eval()

                y_true_a = []
                y_pred_a = []

                for batch in val_loader:
                    epoch_val_loss = val_batch_step(batch, device, net, loss_fn, epoch_val_loss, y_true_a, y_pred_a)

                roc_auc = roc_auc_score(y_true_a, y_pred_a, multi_class='ovo')
                # Same per-sample normalisation as the training loss (one row per sample).
                val_loss_v = epoch_val_loss / len(y_true_a) if y_true_a else 0.0
                print('val loss: ', val_loss_v)
                print(f'roc_auc: {roc_auc}')

                val_roc_auc.append(roc_auc)
                val_loss.append(val_loss_v)

                if roc_auc > best_metrics:
                    print('New best model with test roc auc:', roc_auc)
                    save_model(epoch, net, optimizer, loss, roc_auc)
                    best_metrics = roc_auc

    # Plain floats; list(np.squeeze(...)) raised for a single epoch (0-d array).
    metrics_log['train_acc'] = list(train_acc_log)
    metrics_log['val_roc_auc'] = val_roc_auc
    metrics_log['train_loss'] = train_loss
    metrics_log['val_loss'] = val_loss

    return net

In [58]:
import gc

# Release PyTorch's cached GPU memory and collect Python garbage from earlier cells.
torch.cuda.empty_cache()
gc.collect()

# Delete checkpoints left in the working directory by previous runs.
if Config.remove_old_weights:
    for stale_checkpoint in glob.glob("./*.pt"):
        os.remove(stale_checkpoint)

In [59]:
def get_pos_weights(weights=None):
    """Per-class ``pos_weight`` for ``BCEWithLogitsLoss`` derived from balanced class weights.

    ``pos_weight`` for class c is (#negatives / #positives). With sklearn's
    'balanced' weights ``w_k = N / (K * n_k)`` this equals
    ``w_c * sum(1 / w_k for k != c)``. The former expression ``w0 / w1+w2`` parsed
    as ``(w0 / w1) + w2`` because of operator precedence, and matched no such ratio.

    Args:
        weights: per-class balanced weights; defaults to the notebook-level
            ``class_weights``.

    Returns:
        List of floats, one per class.
    """
    if weights is None:
        weights = class_weights
    weights = [float(w) for w in weights]
    inverse = [1.0 / w for w in weights]
    total_inverse = sum(inverse)
    return [w * (total_inverse - inv) for w, inv in zip(weights, inverse)]

# Per-class pos_weight for the BCE term of the loss (built in the next cell).
class_weights_ = list(get_pos_weights())

In [70]:
net = SWIN().to(Config.DEVICE)
metrics_log = {}

if Config.train or Config.val:
    lr = 1e-3
    # Combined objective: multi-label BCE (per-class pos_weight) + class-weighted cross-entropy.
    bce_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.FloatTensor(class_weights_).to(Config.DEVICE))
    ce_loss = torch.nn.CrossEntropyLoss(weight=torch.FloatTensor(class_weights).to(Config.DEVICE))
    loss_fn = [bce_loss, ce_loss]
    # Registered over all parameters (frozen backbone included); the fine-tuning
    # cell reuses this optimizer after unfreezing.
    optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = StepLR(optimizer, step_size=2, gamma=0.85)  # LR *= 0.85 every 2 epochs

    if Config.train:
        net = train(net, loss_fn, optimizer, scheduler, train_loader, val_loader, metrics_log, n_epoch=Config.EPOCHS)

In [62]:
if Config.train:
    # Validation runs only every Config.VAL_IN_EPOCHS epochs, so index each series by
    # epoch; equal-length columns were required before and crashed when VAL_IN_EPOCHS > 1.
    n_epochs = len(metrics_log['train_acc'])
    val_epochs = range(0, n_epochs, Config.VAL_IN_EPOCHS)
    acc_history = pd.DataFrame({
        'train_acc': pd.Series(metrics_log['train_acc']),
        'val_roc_auc': pd.Series(metrics_log['val_roc_auc'], index=val_epochs),
    })
    ax = sns.lineplot(data=acc_history)
    ax.set_xticks(range(n_epochs))
    ax.set(xlabel='epoch', title='Train accuracy / validation ROC AUC')

In [63]:
if Config.train:
    # Index validation losses by the epochs they were measured on (every VAL_IN_EPOCHS).
    n_epochs = len(metrics_log['train_loss'])
    val_epochs = range(0, n_epochs, Config.VAL_IN_EPOCHS)
    loss_history = pd.DataFrame({
        'train_loss': pd.Series(metrics_log['train_loss']),
        'val_loss': pd.Series(metrics_log['val_loss'], index=val_epochs),
    })
    ax = sns.lineplot(data=loss_history)
    ax.set_xticks(range(n_epochs))
    ax.set(xlabel='epoch', ylabel='loss', title='Train / validation loss')

Fine-tuning

In [64]:
if Config.fine_tune:
    # Unfreeze the backbone (its top-level 'norm' children stay frozen). The optimizer
    # was built over net.parameters(), so the newly trainable weights are already in it.
    net.set_parameter_requires_grad(freeze=False)

    lr = 5e-4
    if Config.load_weights:
        # Resume optimizer state from this session's latest checkpoint, falling back
        # to the configured weights file when none has been written yet.
        ckpt_path = Config.new_current_weights_path or Config.weights_path
        checkpoint = torch.load(ckpt_path, map_location=Config.DEVICE)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Set the fine-tuning LR after any state load, which restores the saved LR.
    for group in optimizer.param_groups:
        group['lr'] = lr

    # Same positional signature as the training cell (scheduler and metrics_log were missing).
    net = train(net, loss_fn, optimizer, scheduler, train_loader, val_loader, metrics_log, n_epoch=Config.FT_EPOCHS)

### Load best model weights

In [71]:
# Choose which checkpoint to evaluate: the configured file ('old') or the best one
# written by save_model() in this session ('new').
if Config.weights_path_var == 'old':
    model_path = Config.weights_path
elif Config.weights_path_var == 'new':
    model_path = Config.new_current_weights_path
else:
    # Previously this only printed and then failed with a NameError on model_path.
    raise ValueError(f"incorrect var value: Config.weights_path_var={Config.weights_path_var!r}, expected 'old' or 'new'")

print(model_path)

if Config.load_weights:
    # map_location lets a GPU-saved checkpoint load on whatever device is configured.
    checkpoint = torch.load(model_path, map_location=Config.DEVICE)
    net.load_state_dict(checkpoint['model_state_dict'])

# Val

In [72]:
# Decode validation probabilities with custom_threshold() instead of plain argmax.
# Deliberately not named `custom_threshold`: the function below would rebind that
# name, and a function object is always truthy, so the rule would always apply.
USE_CUSTOM_THRESHOLD = False

def custom_threshold(preds):
    """Map per-class sigmoid probabilities ``preds`` (N x 3) to class ids with hand-tuned rules.

    Priority per row: class 2 if p2 > 0.3 and p0, p1 < 0.6; else class 0 if
    p0 > 0.3 and p1, p2 < 0.6; else the argmax. (The class-2 rule wins when both
    hold, matching the original statement order; previously the class-0 rule
    was always overwritten and had no effect.)

    Returns:
        int8 tensor of shape (N,) on the same device as ``preds``.
    """
    new_preds = torch.argmax(preds, 1).to(torch.int8)
    p0, p1, p2 = preds[:, 0], preds[:, 1], preds[:, 2]
    new_preds[(p0 > 0.3) & (p1 < 0.6) & (p2 < 0.6)] = 0
    new_preds[(p2 > 0.3) & (p1 < 0.6) & (p0 < 0.6)] = 2
    return new_preds

with torch.no_grad():
    net.eval()
    epoch_val_loss = 0.0

    y_true_a = []
    y_pred_a = []
    y_pred_probas = []

    for x, y in val_loader:
        x = x.to(Config.DEVICE)
        y = y.to(Config.DEVICE)
        logits = net(x)
        target = F.one_hot(y, num_classes=3).float()
        loss = loss_fn[0](logits, target) + loss_fn[1](logits, target)

        probas = torch.sigmoid(logits)
        y_pred_probas.extend(probas.detach().cpu().numpy())

        y_pred = custom_threshold(probas) if USE_CUSTOM_THRESHOLD else torch.argmax(probas, 1)

        y_true = y.detach().cpu().numpy()
        y_pred = y_pred.view(-1).detach().cpu().numpy()

        y_true_a.extend(np.eye(3)[y_true])
        y_pred_a.extend(np.eye(3)[y_pred])

        epoch_val_loss += y_pred.shape[0] * loss.item()

    # NOTE(review): ROC AUC is computed from hard one-hot predictions; y_pred_probas
    # holds the probabilities if a threshold-free AUC is wanted.
    roc_auc = roc_auc_score(y_true_a, y_pred_a, multi_class='ovo')
    # epoch_val_loss sums batch_size * batch_loss, so divide by the number of samples.
    print('val loss: ', epoch_val_loss / len(y_true_a))
    print(f'roc_auc: {roc_auc}')

In [None]:
# mask = np.argmax(y_true_a, 1) != np.argmax(y_pred_a, 1)
# y_true_v = np.argmax(y_true_a, 1)
# y_pred_probas = np.squeeze(y_pred_probas)
# y_pred_probas[mask].shape

In [None]:
# y_true_v[mask][y_true_v[mask] == 1]
# np.where(y_true_v[mask] == 2)

In [None]:
# y_pred_probas[mask]

In [73]:
print(classification_report(y_true_a, y_pred_a))

true_labels = np.argmax(y_true_a, axis=1)
pred_labels = np.argmax(y_pred_a, axis=1)
matrix = confusion_matrix(true_labels, pred_labels)
# Rows are true classes, columns predicted classes, so diagonal / column sums is
# per-class precision (this value was previously computed and discarded).
per_class_precision = matrix.diagonal() / matrix.sum(axis=0)
print('per-class precision:', per_class_precision)

fig, ax = plt.subplots(figsize=(8, 6), dpi=80)
sns.heatmap(matrix, annot=True, fmt='d', ax=ax)
ax.set(xlabel='predicted class', ylabel='true class', title='Validation confusion matrix');

In [74]:
# Raw confusion-matrix counts (rows: true class, columns: predicted class).
print(str(matrix))

# Inference

In [75]:
# Predict a class for every test image; file names are collected in the same order.
preds = []
file_names_all = []

with torch.no_grad():
    net.eval()
    for images, file_names in tqdm(test_loader):
        file_names_all.extend(file_names)
        probas = torch.sigmoid(net(images.to(Config.DEVICE)))
        # Same decoding as training/validation: argmax over per-class sigmoid scores.
        preds.extend(torch.argmax(probas, 1).cpu().tolist())

## Submit

In [81]:
if Config.submit:

    submit_csv_file_name = 'swin_base_multilabel_val_spl.csv'

    # ID_img is the file name without its extension. splitext strips only the final
    # suffix, whereas split('.')[0] would truncate at the first dot in the name.
    submit = pd.DataFrame({
        'ID_img': [os.path.splitext(filen)[0] for filen in file_names_all],
        'class': preds,
    })
    submit.to_csv(submit_csv_file_name, index=False)
    print(submit.head(5))

In [78]:
# Autoplay a short beep, presumably as an audible "run finished" signal.
beep = IPython.display.Audio(url="https://upload.wikimedia.org/wikipedia/commons/0/05/Beep-09.ogg", autoplay=True)
display(beep)

In [82]:
# Clickable download link for the submission file in the cell output
# (approach from https://www.kaggle.com/getting-started/168312).
FileLink('./' + submit_csv_file_name)

<!-- <a href=""> Download submission file </a> (paste the file name here) -->