## Install + Import Packages
Internet access is disabled in this competition, so the required `.whl` files are uploaded to a Kaggle dataset and installed from there.

In [None]:
!pip install python-box --no-index --find-links=file:///kaggle/input/python-box-dicts

In [None]:
import gc
import glob
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.io import ImageReadMode, read_image
from torchvision import transforms
from sklearn.model_selection import ShuffleSplit
from box import Box

import sys
sys.path.append('../input/timm-pytorch-image-models/pytorch-image-models-master')
import timm
import pytorch_lightning as pl
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.callbacks import EarlyStopping, GPUStatsMonitor
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning import LightningModule, LightningDataModule

In [None]:
# All hyper-parameters live in one Box so they can be read with attribute syntax
# (e.g. `config.epochs.pretrain`); `config.trainer` is unpacked straight into
# every `pl.Trainer(...)` call.
config = Box({
    'seed': 42,
    # Epoch budget per stage: teacher pre-training, meta pseudo-labelling,
    # and the final fine-tuning of the student.
    'epochs': dict(pretrain=5, meta_ps=15, finetune=5),
    'trainer': dict(
        gpus=1,
        precision=16,
        num_sanity_val_steps=0,
        fast_dev_run=False,
    ),
    'model_name': 'swin_tiny_patch4_window7_224',
    'lr': 1e-5,
    'batch_size': 64,
    'splits': 1,      # n_splits for the ShuffleSplit train/validation splitting
    'val_size': 0.1,  # test_size for the ShuffleSplit (validation fraction)
})

# Seed the global RNGs so the splits and the training runs are reproducible.
seed_everything(config.seed)

## Datasets & Dataloaders (w/ Augmentation)

In [None]:
# For this competition in particular, some augmentations may change the "correct" label of an
# image (e.g. a squashed image may earn a lower score), so only a minimal set of
# transformations is used for now.
#
# Both pipelines take uint8 image tensors: ConvertImageDtype rescales them to floats in [0, 1]
# and Normalize then standardises each channel with the ImageNet mean/std.
# NOTE(review): the LightningModules apply these pipelines to whole batches, and
# RandomHorizontalFlip makes one random decision per call, so an entire batch is flipped
# (or not) together rather than each image independently.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


def _to_normalized_float(*augmentations):
    """Compose *augmentations* followed by float conversion and normalisation."""
    return transforms.Compose([
        *augmentations,
        transforms.ConvertImageDtype(torch.float),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ])


my_transforms = {
    'train': _to_normalized_float(transforms.RandomHorizontalFlip()),
    'val': _to_normalized_float(),
}

In [None]:
class PawpularityDataset(Dataset):
    """Competition images, optionally paired with their Pawpularity labels.

    Args:
        df: DataFrame with an ``Id`` column (image file stem) and, optionally,
            a ``Pawpularity`` column.
        image_size: every image is resized to ``image_size x image_size``.
        test: read from the ``test`` image folder instead of ``train``.
        root: competition data directory holding the ``train``/``test`` folders.

    ``__getitem__`` returns ``(image, label)`` when labels exist, otherwise just
    ``image``, where ``image`` is a uint8 tensor of shape
    ``(3, image_size, image_size)``.
    """

    def __init__(self, df, image_size=224, test=False, root='../input/petfinder-pawpularity-score'):
        image_dir = os.path.join(root, 'test' if test else 'train')
        self.image_paths = df['Id'].apply(
            lambda image_id: os.path.join(image_dir, image_id + '.jpg')
        ).values
        self.labels = df['Pawpularity'].values if 'Pawpularity' in df.columns else None
        self.transform = transforms.Resize([image_size, image_size])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Decode as RGB so every sample has exactly 3 channels, whatever the file's
        # colour mode; mixed channel counts cannot be stacked into one batch.
        image = read_image(self.image_paths[idx], mode=ImageReadMode.RGB)
        image = self.transform(image)
        if self.labels is None:
            return image
        return image, self.labels[idx]

    
class UnlabeledDataset(Dataset):
    """Unlabeled images used as the pseudo-labelling pool.

    Args:
        image_paths: sequence of image file paths.
        image_size: every image is resized to ``image_size x image_size``.

    Each item is a uint8 tensor of shape ``(3, image_size, image_size)``.
    """

    def __init__(self, image_paths, image_size=224):
        self.image_paths = image_paths
        self.transform = transforms.Resize([image_size, image_size])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        try:
            # Force 3 channels: grayscale or RGBA files would otherwise produce 1- or
            # 4-channel tensors that cannot be collated with RGB ones.
            image = read_image(path, mode=ImageReadMode.RGB)
        except Exception:
            # Best-effort fallback for files torchvision fails to decode.
            print(f'Failed to load {path} normally, perhaps image has an alpha channel?')
            from PIL import Image
            # convert('RGB') handles grayscale, palette and alpha modes alike; the
            # context manager closes the file handle PIL keeps open.
            with Image.open(path) as pil_image:
                rgb = np.array(pil_image.convert('RGB'))
            image = torch.tensor(np.moveaxis(rgb, 2, 0))  # HWC -> CHW
        return self.transform(image)


# Datamodule for semi-supervised training
class PawpularitySSLDataModule(LightningDataModule):
    """Labeled Pawpularity data plus a pool of unlabeled images.

    Args:
        labeled_data_dir: directory containing ``train.csv`` and ``test.csv``.
        split: index of the train/validation split to use. Split 0 (re)creates
            ``train_splits.csv``; every other split reads it back.
        unlabeled_globs: glob patterns for the unlabeled images, or ``None``.

    ``train_dataloader`` returns ``{'labeled': ..., 'unlabeled': ...}``; the
    ``'unlabeled'`` entry is omitted when no unlabeled globs were given.

    Raises:
        FileNotFoundError: from ``setup`` when an unlabeled glob matches no files.
    """

    def __init__(self, labeled_data_dir, split: int, unlabeled_globs=None):
        super().__init__()
        self.labeled_data_dir = labeled_data_dir
        self.unlabeled_globs = unlabeled_globs
        self.split = split
        self.unlabeled_dataset = None

    def _load_split_frame(self):
        """Return train.csv with a ``Split`` column marking each split's validation rows."""
        if self.split != 0:
            # Split 0 already persisted the assignment; reuse it so all splits agree.
            return pd.read_csv('train_splits.csv')
        df = pd.read_csv(os.path.join(self.labeled_data_dir, 'train.csv'))
        splitter = ShuffleSplit(n_splits=config.splits, test_size=config.val_size, random_state=config.seed)
        # Rows never drawn for validation keep Split = NaN and are always training rows.
        # NOTE(review): ShuffleSplit validation sets can overlap when config.splits > 1;
        # later splits then overwrite earlier assignments for the shared rows.
        for i, (_, valid_idxs) in enumerate(splitter.split(df)):
            df.loc[valid_idxs, 'Split'] = i
        df.to_csv('train_splits.csv', index=False)
        return df

    def _build_unlabeled_dataset(self):
        """Concatenate one UnlabeledDataset per glob; ``None`` when there are no globs."""
        if not self.unlabeled_globs:
            return None
        datasets = []
        for globstr in self.unlabeled_globs:
            # Sorted so the dataset order does not depend on filesystem listing order.
            image_paths = sorted(glob.glob(globstr))
            if not image_paths:
                raise FileNotFoundError(f'No unlabeled images match {globstr!r}')
            datasets.append(UnlabeledDataset(image_paths))
        return torch.utils.data.ConcatDataset(datasets)

    def setup(self, stage=None):
        # Datasets are built here rather than in prepare_data: Lightning runs
        # prepare_data on a single process, so state assigned there is not shared.
        if stage in (None, 'fit'):
            df = self._load_split_frame()
            self.labeled_train_dataset = PawpularityDataset(df[df['Split'] != self.split])
            self.labeled_val_dataset = PawpularityDataset(df[df['Split'] == self.split])
            self.unlabeled_dataset = self._build_unlabeled_dataset()

        if stage in (None, 'test'):
            test_df = pd.read_csv(os.path.join(self.labeled_data_dir, 'test.csv'))
            self.test_dataset = PawpularityDataset(test_df, test=True)

    @staticmethod
    def _loader(dataset, training):
        """Shuffle and drop the ragged last batch only when training."""
        return DataLoader(
            dataset,
            batch_size=config.batch_size,
            shuffle=training,
            num_workers=2,
            drop_last=training,
        )

    def train_dataloader(self):
        loaders = {'labeled': self._loader(self.labeled_train_dataset, training=True)}
        if self.unlabeled_dataset is not None:
            loaders['unlabeled'] = self._loader(self.unlabeled_dataset, training=True)
        return loaders

    def val_dataloader(self):
        return self._loader(self.labeled_val_dataset, training=False)

    def test_dataloader(self):
        return self._loader(self.test_dataset, training=False)


class CanonicalSupervisedDataModule(LightningDataModule):
    """Plain supervised datamodule over the labeled competition images.

    Args:
        data_dir: directory containing ``train.csv`` and ``test.csv``.
        split: index of the validation split. Split 0 draws the seeded
            ShuffleSplit and writes ``train_splits.csv``; other splits reuse it.
    """

    def __init__(self, data_dir, split):
        super().__init__()
        self.data_dir = data_dir
        self.split = split

    def _split_frame(self):
        """Load train.csv annotated with a ``Split`` column (NaN = never validation)."""
        if self.split != 0:
            return pd.read_csv('train_splits.csv')
        frame = pd.read_csv(os.path.join(self.data_dir, 'train.csv'))
        splitter = ShuffleSplit(n_splits=config.splits, test_size=config.val_size, random_state=config.seed)
        # NOTE(review): ShuffleSplit validation sets can overlap when config.splits > 1,
        # in which case later splits overwrite earlier assignments.
        for fold, (_, val_rows) in enumerate(splitter.split(frame)):
            frame.loc[val_rows, 'Split'] = fold
        frame.to_csv('train_splits.csv', index=False)
        return frame

    @staticmethod
    def _make_loader(dataset, training):
        """Training loaders shuffle and drop the last partial batch; others do neither."""
        return DataLoader(
            dataset,
            batch_size=config.batch_size,
            shuffle=training,
            num_workers=2,
            drop_last=training,
        )

    def setup(self, stage=None):
        if stage in (None, 'fit'):
            frame = self._split_frame()
            in_val = frame['Split'] == self.split
            self.train_dataset = PawpularityDataset(frame[~in_val])
            self.val_dataset = PawpularityDataset(frame[in_val])

        if stage in (None, 'test'):
            self.test_dataset = PawpularityDataset(
                pd.read_csv(os.path.join(self.data_dir, 'test.csv')),
                test=True,
            )

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, training=True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, training=False)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset, training=False)

## Defining PyTorch Lightning Systems

### Supervised Learning

In [None]:
class SwinModel(pl.LightningModule):
    """timm backbone + dropout/linear head predicting Pawpularity as a [0, 1] target.

    Labels (0-100) are divided by 100 and trained with BCE-with-logits, so a
    prediction is ``100 * sigmoid(logit)``.

    Args:
        pretrained: load backbone weights from the bundled checkpoint file.
        split: split index; used to name the ``submission_{split}.csv`` output.
    """

    def __init__(self, pretrained, split):
        super().__init__()

        self.split = split

        # num_classes=0 makes timm return pooled features instead of class logits.
        self.backbone = timm.create_model(
            config.model_name,
            num_classes=0,
            in_chans=3,
        )
        if pretrained:
            # Load onto the CPU first; the model itself is moved to the GPU later.
            self.backbone.load_state_dict(torch.load(
                '../input/swin-tiny-patch4-window7-224-pretrained/swin_tiny_patch4_window7_224_pretrained.pth',
                map_location='cpu',
            ))

        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(self.backbone.num_features, 1),
        )

        # Scaled labels are treated as soft binary targets.
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, x):
        # One logit per image, shape (batch, 1).
        return self.fc(self.backbone(x))

    def training_step(self, batch, batch_idx):
        images, labels = batch
        images = my_transforms['train'](images)
        labels = labels / 100.0
        loss = self.criterion(self(images).squeeze(1), labels)
        return loss

    def validation_step(self, batch, batch_idx):
        with torch.no_grad():
            images, labels = batch
            images = my_transforms['val'](images)
            labels = labels / 100.0
            logits = self(images).squeeze(1)
            loss = self.criterion(logits, labels)
            self.log('val_loss', loss)

            # Summed squared error on the [0, 1] scale plus the batch size, so the
            # epoch-level mean is exact even when the last batch is smaller.
            mse_loss = torch.nn.MSELoss(reduction='sum')(torch.sigmoid(logits), labels).detach()
            return mse_loss, len(logits)

    def validation_epoch_end(self, outputs):
        with torch.no_grad():
            self.log('mse_loss', sum(o[0] for o in outputs) / sum(o[1] for o in outputs))

    def test_step(self, batch, batch_idx):
        with torch.no_grad():
            images = my_transforms['val'](batch)
            # squeeze(1), not squeeze(): a final batch holding a single image must stay
            # 1-D, otherwise torch.cat in test_epoch_end fails on a 0-d tensor.
            logits = self(images).squeeze(1)
            predictions = 100.0 * torch.sigmoid(logits).detach()
            return predictions

    def test_epoch_end(self, outputs):
        with torch.no_grad():
            submission_df = pd.read_csv('../input/petfinder-pawpularity-score/sample_submission.csv')
            # NOTE(review): assumes the test loader yields images in the same order as
            # sample_submission.csv — confirm against test.csv.
            # float() undoes fp16 outputs under precision=16; numpy() gives pandas plain data.
            submission_df['Pawpularity'] = torch.cat(outputs).detach().float().cpu().numpy()
            submission_df.to_csv(f'submission_{self.split}.csv', index=False)

    def configure_optimizers(self):
        return optim.AdamW(self.parameters(), lr=config.lr)

### Meta Pseudo Labels

See Pham et al., *Meta Pseudo Labels* (https://arxiv.org/abs/2003.10580).

In [None]:
class MetaPseudoLabelSystem(pl.LightningModule):
    """Teacher/student semi-supervised training in the spirit of Meta Pseudo Labels.

    Every training step:
      1. the teacher produces soft pseudo-labels (sigmoid outputs) for an unlabeled batch,
      2. the student takes one optimizer step towards those pseudo-labels,
      3. the teacher takes one optimizer step on the labeled batch.
    ``forward`` and validation use the student only.

    NOTE(review): the MPL teacher feedback term (a reward derived from the change in
    the student's labeled loss) is not implemented; the teacher learns from labels only.

    Args:
        teacher: module producing one logit per image (e.g. ``SwinModel``).
        student: module producing one logit per image.
        split: split index, kept for bookkeeping.
    """

    def __init__(self, teacher: nn.Module, student: nn.Module, split: int):
        super().__init__()
        # Two optimizers are stepped by hand in training_step.
        self.automatic_optimization = False

        self.teacher = teacher
        self.student = student
        self.split = split

        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, x):
        return self.student(x)

    def training_step(self, batch, batch_idx):
        unlabeled_images = my_transforms['train'](batch['unlabeled'])
        labeled_images, labels = batch['labeled']
        labeled_images = my_transforms['train'](labeled_images)
        labels = labels / 100.0  # Pawpularity 0-100 -> BCE target in [0, 1]

        student_opt, teacher_opt = self.optimizers()

        # Pseudo-labels are fixed targets for the student. Computing them without a
        # graph keeps the student's backward pass from reaching the teacher's weights.
        with torch.no_grad():
            teacher_preds = torch.sigmoid(self.teacher(unlabeled_images).squeeze(1))

        student_loss = self.criterion(self.student(unlabeled_images).squeeze(1), teacher_preds)
        student_opt.zero_grad()
        self.manual_backward(student_loss)
        student_opt.step()

        # The teacher's loss must be built from the teacher's own outputs. A loss on the
        # student's predictions has no gradient w.r.t. the teacher's parameters, which
        # would turn teacher_opt.step() into a no-op.
        teacher_loss = self.criterion(self.teacher(labeled_images).squeeze(1), labels)
        teacher_opt.zero_grad()
        self.manual_backward(teacher_loss)
        teacher_opt.step()

        self.log('student_loss', student_loss)
        self.log('teacher_loss', teacher_loss)

    def validation_step(self, batch, batch_idx):
        with torch.no_grad():
            images, labels = batch
            images = my_transforms['val'](images)
            labels = labels / 100.0

            logits = self(images).squeeze(1)
            loss = self.criterion(logits, labels)
            self.log('val_loss', loss)

            # Compare probabilities, not raw logits, with the [0, 1] targets.
            mse_loss = torch.nn.MSELoss(reduction='sum')(torch.sigmoid(logits), labels).detach()
            return mse_loss, len(logits)

    def validation_epoch_end(self, outputs):
        with torch.no_grad():
            self.log('mse_loss', sum(o[0] for o in outputs) / sum(o[1] for o in outputs))

    def configure_optimizers(self):
        # Order matters: training_step unpacks (student_opt, teacher_opt).
        return optim.AdamW(self.student.parameters(), config.lr), optim.AdamW(self.teacher.parameters(), config.lr)

## Training + CV

Note: the explicit `del` and `gc.collect()` calls below try to free memory between training stages; this kind of manual memory management is probably bad practice.

In [None]:
pawpularity_root = '../input/petfinder-pawpularity-score/'
unlabeled_globs = [
    '../input/cat-dataset/CAT_0?/*.jpg',
    '../input/stanford-dogs-dataset/images/Images/*/*.jpg',
]

logger = CSVLogger('logs')


def _make_trainer(max_epochs):
    """Build a fresh Trainer sharing the CSV logger, with early stopping on val_loss."""
    return pl.Trainer(
        max_epochs=max_epochs,
        callbacks=[EarlyStopping(monitor='val_loss'), GPUStatsMonitor()],
        logger=logger,
        **config.trainer,
    )


for split in range(config.splits):
    # Stage 1: supervised pre-training of the teacher on the labeled data.
    my_teacher = SwinModel(pretrained=True, split=split)
    stage_trainer = _make_trainer(config.epochs.pretrain)
    stage_data = CanonicalSupervisedDataModule(pawpularity_root, split)
    stage_trainer.fit(my_teacher, datamodule=stage_data)

    del stage_trainer, stage_data
    gc.collect()

    # Stage 2: meta pseudo-labelling — the teacher labels the unlabeled pool for a new student.
    my_student = SwinModel(pretrained=True, split=split)
    mpsl = MetaPseudoLabelSystem(my_teacher, my_student, split)
    stage_data = PawpularitySSLDataModule(pawpularity_root, split, unlabeled_globs)
    stage_trainer = _make_trainer(config.epochs.meta_ps)
    stage_trainer.fit(mpsl, datamodule=stage_data)

    del stage_trainer, stage_data, mpsl, my_teacher
    gc.collect()

    # Stage 3: freeze the student's backbone, fine-tune its head on labels, then predict.
    for param in my_student.backbone.parameters():
        param.requires_grad = False
    stage_trainer = _make_trainer(config.epochs.finetune)
    stage_data = CanonicalSupervisedDataModule(pawpularity_root, split)
    stage_trainer.fit(my_student, datamodule=stage_data)

    stage_trainer.test(my_student, datamodule=stage_data)

    del stage_trainer, stage_data, my_student
    gc.collect()

## Submit Predictions

In [None]:
# Average the per-split test predictions (one submission_{split}.csv per split,
# written by SwinModel.test_epoch_end) into the final submission file.
test_predictions = []
for split in range(config.splits):
    df = pd.read_csv(f'submission_{split}.csv')
    test_predictions.append(df['Pawpularity'])
# `axis` must be passed by keyword: positional use is deprecated in pandas 1.x and
# rejected by pandas 2.x, where concat's parameters after `objs` are keyword-only.
concat = pd.concat(test_predictions, axis=1)

submission_df = pd.read_csv('../input/petfinder-pawpularity-score/sample_submission.csv')
submission_df['Pawpularity'] = concat.mean(axis=1)
submission_df.to_csv('submission.csv', index=False)
submission_df