In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Recursively walk the Kaggle input mount and print the full path of every file.
input_root = '/kaggle/input'
for dirname, _, filenames in os.walk(input_root):
    for filename in filenames:
        file_path = os.path.join(dirname, filename)
        print(file_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import wandb
from kaggle_secrets import UserSecretsClient

# Read the Weights & Biases API key from the Kaggle secret stored under the
# label "wandb key".
user_secrets = UserSecretsClient()
my_secret = user_secrets.get_secret("wandb key")

# Authenticate this session with Weights & Biases using that key.
wandb.login(key=my_secret)

In [None]:
import torch

import pytorch_lightning as pl
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.datasets import CIFAR10
from torchvision import transforms

# Convert images to tensors, then normalise every channel with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# CIFAR-10 training split, stored under /kaggle/temp/ (download requested).
cifar_train = CIFAR10(
    root='/kaggle/temp/',
    train=True,
    transform=transform,
    download=True,
)

In [None]:
# datasets contents
import numpy as np
import torch

from torch.utils.data import TensorDataset
from torchvision.datasets import CIFAR10
from torchvision.transforms.functional import rotate


def rotated_dataset(images, method='all', return_dataset=True, tensor_embedding=False):
    """Build rotated copies of ``images`` together with their rotation targets.

    Args:
        images: iterable of image tensors; each one is passed to ``rotate``.
        method: ``'all'`` emits every rotation in (0, 90, -90, 180) for each
            image; any other value draws one rotation per image with
            ``np.random.choice``.
        return_dataset: when true, the result is wrapped in a
            ``TensorDataset``; otherwise the ``(images, targets)`` tensors are
            returned as a tuple.
        tensor_embedding: when true, each target is the ``(cos, sin)`` pair of
            the rotation angle; otherwise it is the index of the angle in the
            list of supported angles.

    Returns:
        ``TensorDataset(images, targets)`` or the tuple ``(images, targets)``.
    """
    angles = [0, 90, -90, 180]
    use_every_angle = (method == 'all')

    rotated_images, applied_angles = [], []
    for image in images:
        chosen = angles if use_every_angle else [np.random.choice(angles)]
        for angle in chosen:
            rotated_images.append(rotate(image, int(angle)))
            applied_angles.append(angle)

    image_batch = torch.stack(rotated_images)
    angle_tensor = torch.as_tensor(applied_angles)

    if tensor_embedding:
        # degrees -> radians, then embed each angle as a point on the unit circle
        radians = (2*np.pi / 360) * angle_tensor
        targets = torch.stack([torch.cos(radians), torch.sin(radians)], dim=-1)
    else:
        # class label = position of the angle in `angles`
        class_of = {angle: position for position, angle in enumerate(angles)}
        targets = torch.as_tensor([class_of[value] for value in angle_tensor.tolist()])

    if return_dataset:
        return TensorDataset(image_batch, targets)
    return image_batch, targets

In [None]:
# TTT Model
import torch
import pytorch_lightning as pl

from torch import nn
from collections import OrderedDict
from torchvision.models import resnet50, ResNet50_Weights


class TTTModel(pl.LightningModule):
    """ ResNet-50 classifier with an auxiliary rotation-prediction branch,
        used for Test Time Training (TTT).

        Two pretrained ResNet-50s are built. ``primary`` followed by
        ``class_decoder`` is the classifier. ``TTTBranch`` reuses the primary's
        modules up to and including ``branch_layer`` and continues with the
        secondary's remaining modules; followed by ``angle_decoder`` it
        predicts the rotation that was applied to an image.

        Args:
            branch_layer: name of the last ResNet module shared by both paths
                (e.g. 'layer1' ... 'layer4').
            train_mode: what ``training_step`` optimises --
                'base_only': classification loss only;
                'test_time': angle loss only (the batch targets are used as
                angle targets);
                'joint': classification loss plus the angle loss computed on
                randomly rotated copies of the batch.
            target_embedding: 'angular' uses 2-d (cos, sin) targets with a
                cosine-distance loss; any other value treats the rotation as
                a 4-way classification problem with cross-entropy.
            n_classes: number of outputs of the classification head.
    """
    def __init__(self, branch_layer='layer2', train_mode='base_only',
                 target_embedding='angular', n_classes=10
                 ):
        super(TTTModel, self).__init__()
        self.save_hyperparameters()

        self.primary = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.secondary = resnet50(weights=ResNet50_Weights.DEFAULT)

        # setup angle decoding model by mixing the two resnets together:
        # modules up to and including `branch_layer` are the primary's own
        # module objects (shared), later modules come from the secondary.
        module_names = list(self.primary._modules.keys())
        branch_ind = module_names.index(branch_layer)

        TTTBranch = OrderedDict([
                            (key, self.primary._modules[key]) if i <= branch_ind
                            else (key, self.secondary._modules[key])
                            for i, key in enumerate(module_names)
                            ])
        # slight quirk of the Resnet implementation: it doesn't work as a
        # simple series of modules. We need to flatten before the final 'fc'.
        TTTBranch['flatten'] = nn.Flatten()
        TTTBranch.move_to_end('fc')
        self.TTTBranch = nn.Sequential(TTTBranch)

        # decoders (both consume the 1000-d output of the ResNet 'fc' layer)
        self.class_decoder = nn.Linear(1000, n_classes)

        angle_dims = (2 if target_embedding == 'angular' else 4)
        self.angle_decoder = nn.Linear(1000, angle_dims)

        # losses
        self.classification_loss = nn.CrossEntropyLoss()

        if target_embedding == 'angular':
            # Cosine *distance*: minimising the raw similarity would push the
            # prediction towards the opposite of the target direction.
            self.angle_loss = self._cosine_distance
        else:
            self.angle_loss = nn.CrossEntropyLoss()

        # mode for training
        self.train_mode = train_mode

    @staticmethod
    def _cosine_distance(prediction, target):
        """ Mean of (1 - cosine similarity); 0 when directions coincide. """
        return (1 - nn.functional.cosine_similarity(prediction, target)).mean()

    @staticmethod
    def _set_requires_grad(module, flag):
        """ Set ``requires_grad`` to ``flag`` on every parameter of ``module``. """
        for parameter in module.parameters():
            parameter.requires_grad = flag

    @staticmethod
    def _accuracy(outputs, y):
        """ Fraction of rows whose arg-max along dim 1 equals the label. """
        predictions = torch.argmax(outputs, dim=1)
        return (predictions == y).to(torch.float32).mean()

    def forward(self, x):
        """ Class logits: primary ResNet followed by the class decoder. """
        return self.class_decoder(self.primary(x))

    def forward_branch(self, x):
        """ Angle prediction: mixed TTT branch followed by the angle decoder. """
        return self.angle_decoder(self.TTTBranch(x))

    # train, val, test logic
    def training_step(self, batch, batchind=None):
        """ Compute and log the training loss selected by ``self.train_mode``.

            Raises:
                ValueError: if ``self.train_mode`` is not one of
                    'test_time', 'base_only' or 'joint'.
        """
        x, y = batch
        if self.train_mode == 'test_time':
            loss = self.angle_loss(self.forward_branch(x), y)

        elif self.train_mode == 'base_only':
            loss = self.classification_loss(self.forward(x), y)

        elif self.train_mode == 'joint':
            classification_loss = self.classification_loss(self.forward(x), y)

            # one randomly chosen rotation per image in the batch
            rotated, angle = rotated_dataset(x,  method='sample', return_dataset=False,
                                             tensor_embedding=(
                                                self.hparams.target_embedding == 'angular'
                                                )
                                             )
            # the angle targets are moved onto the model's device before use
            angle_loss = self.angle_loss(self.forward_branch(rotated), angle.to(self.device))

            loss = classification_loss + angle_loss

        else:
            raise ValueError(
                f"Unknown train_mode {self.train_mode!r}; expected "
                "'test_time', 'base_only' or 'joint'."
            )

        self.log('Train loss', loss)
        return loss

    def validation_step(self, batch, batch_ind=None):
        """ Log classification loss and accuracy of the primary path. """
        x, y = batch
        outputs = self.forward(x)
        accuracy = self._accuracy(outputs, y)

        loss = self.classification_loss(outputs, y)
        self.log('Val loss', loss)
        self.log('Val accuracy', accuracy)

    def test_step(self, batch, batch_ind=None):
        """ Log classification accuracy of the primary path. """
        x, y = batch
        outputs = self.forward(x)
        accuracy = self._accuracy(outputs, y)

        self.log('test accuracy', accuracy)

    def configure_optimizers(self):
        """ Adam over all module parameters with a learning rate of 1e-3. """
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    # utilities for freezing different parts of the model.
    def freeze_primary(self):
        """ Disable gradients for every parameter of ``self.primary``. """
        self._set_requires_grad(self.primary, False)

    def unfreeze_primary(self):
        """ Enable gradients for every parameter of ``self.primary``. """
        self._set_requires_grad(self.primary, True)

    def freeze_secondary(self):
        """ Disable gradients for every parameter of ``self.secondary``. """
        self._set_requires_grad(self.secondary, False)

    def unfreeze_secondary(self):
        """ Enable gradients for every parameter of ``self.secondary``. """
        self._set_requires_grad(self.secondary, True)

    # Backward-compatible alias for the original, misspelled method name.
    unfreeze_sceondary = unfreeze_secondary


In [None]:
# Hold out 5% of the CIFAR-10 training data for validation.
train_set, val_set = torch.utils.data.random_split(cifar_train, (0.95, 0.05))

train_dl = torch.utils.data.DataLoader(train_set, batch_size=32,
                                          shuffle=True, num_workers=4, persistent_workers=True)
# The validation loader must wrap the held-out split (val_set), not train_set:
# otherwise the validation metrics are computed on data the model was trained on.
val_dl = torch.utils.data.DataLoader(val_set, batch_size=32, num_workers=4, persistent_workers=True)


In [None]:
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

import wandb

def train_model(model, dir_name):
    """Fit ``model`` on the notebook's ``train_dl`` / ``val_dl`` loaders.

    A W&B run is opened in project 'ttt' and used for logging. Training stops
    early once "Val accuracy" has not improved for 5 validation checks, and
    the single best checkpoint by "Val accuracy" is written to
    ``/kaggle/working/<dir_name>``.

    Args:
        model: the LightningModule to train.
        dir_name: sub-directory of /kaggle/working/ that receives the checkpoint.
    """
    wandb.init(reinit=True, project='ttt')

    # Close the W&B run even when setup or training raises, and report a
    # non-zero exit code in that case so the run is not left open.
    exit_code = 1
    try:
        logger = WandbLogger(project='ttt')
        early_stopping = EarlyStopping(monitor="Val accuracy", mode="max", patience=5)
        checkpoint_callback = ModelCheckpoint(dirpath=f'/kaggle/working/{dir_name}', save_top_k=1,
                                              monitor="Val accuracy", mode='max')

        trainer = pl.Trainer(logger=logger, callbacks=[checkpoint_callback, early_stopping])
        trainer.fit(model, train_dl, val_dl)
        exit_code = 0
    finally:
        wandb.finish(exit_code=exit_code)
    

In [None]:
# Train one joint-mode model per branch point, layer1 through layer4; each run
# uses the branch layer's name as its checkpoint sub-directory.
for i in range(1, 5):
    branch_name = f'layer{i}'
    model = TTTModel(branch_layer=branch_name, train_mode='joint')
    train_model(model, branch_name)