In [None]:
import logging

import pandas as pd
import torch
import torch.nn.functional as F
from nn_core.common import PROJECT_ROOT
import random

from pathlib import Path

try:
    # be ready for 3.10 when it drops
    from enum import StrEnum
except ImportError:
    from backports.strenum import StrEnum
from pytorch_lightning import seed_everything
import matplotlib.pyplot as plt
import random
from collections import namedtuple
import timm
from transformers import AutoModel, AutoProcessor
from typing import Sequence, List
from PIL.Image import Image
from tqdm import tqdm
import functools
from timm.data import resolve_data_config
from datasets import load_dataset, load_from_disk, Dataset, DatasetDict
import torchvision
import torch
from timm.data import create_transform

In [None]:
# Load the partitioned CIFAR-100 splits from disk and read how many tasks they contain.
# NOTE(review): `MyDatasetDict` is not imported in any cell visible here, so on a fresh
# kernel this line raises NameError — import it from the project package before running.
data: MyDatasetDict = MyDatasetDict.load_from_disk("../data/cifar100/partitioned")
num_tasks = data["metadata"]["num_tasks"]

# Every task key task_0 ... task_{num_tasks} gets the same anchor split.
# (Presumably task_0 is an extra split besides the num_tasks partitions — confirm.)
shared_anchors = data["anchors"]
for task_idx in range(num_tasks + 1):
    data[f"task_{task_idx}_anchors"] = shared_anchors

In [None]:
# Image preprocessing shared by the train, test and anchor splits (no augmentation):
# ToTensor, then per-channel Normalize with the commonly used ImageNet mean/std values.
_to_tensor = torchvision.transforms.ToTensor()
_normalize = torchvision.transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
transform_func = torchvision.transforms.Compose([_to_tensor, _normalize])

In [None]:
# Apply `transform_func` to the "x" column of every per-task split, then expose "x" and
# "y" as torch tensors.
# A tuple (not a set) keeps the processing order identical across kernels: the iteration
# order of a set of strings depends on PYTHONHASHSEED.
# NOTE: this cell is not idempotent — running it twice would apply transform_func again
# to already-transformed data. Re-run the loading cell before re-running this one.
for mode in ("train", "test", "anchors"):
    for task_ind in range(num_tasks + 1):
        split_name = f"task_{task_ind}_{mode}"
        data[split_name] = data[split_name].map(
            lambda x: {"x": transform_func(x["x"])}, batched=False
        )
        data[split_name].set_format("torch", columns=["x", "y"])

In [None]:
import logging
from typing import Any, Sequence, Tuple, Union, Dict

import hydra
import torch
from hydra.utils import instantiate
from nn_core.model_logging import NNLogger
from omegaconf import DictConfig
from torch.optim import Optimizer
import pytorch_lightning as pl

from la.pl_modules.pl_module import MyLightningModule

pylogger = logging.getLogger(__name__)


class ResNet(MyLightningModule):
    """LightningModule that delegates its forward pass to an injected backbone.

    Args:
        num_classes: forwarded to ``MyLightningModule.__init__``.
        input_dim: forwarded to ``MyLightningModule.__init__``.
        model: the backbone module called by ``forward``.
        *args, **kwargs: forwarded unchanged to ``MyLightningModule.__init__``.
    """

    logger: NNLogger

    def __init__(self, num_classes, input_dim, model, *args, **kwargs) -> None:
        # Same call as passing the keywords first; star-args are just written in the
        # conventional position.
        super().__init__(*args, num_classes=num_classes, input_dim=input_dim, **kwargs)

        # Record init arguments as hyperparameters (excluding `metadata`, not sent to the logger).
        # NOTE(review): `model` is an nn.Module and therefore ends up in hparams as well;
        # Lightning recommends listing such modules in `ignore` — confirm before changing,
        # since checkpoint loading may currently rely on it.
        self.save_hyperparameters(logger=False, ignore=("metadata",))

        self.model = model

    def forward(self, x: torch.Tensor) -> Dict:
        """Run the wrapped backbone on ``x``.

        Returns:
            Exactly what ``self.model(x)`` returns (for ``ResNetModule`` in this notebook,
            a dict with ``"embeds"`` and ``"logits"``). No loss is computed here.
        """
        return self.model(x)

    def configure_optimizers(
        self,
    ) -> Union[Optimizer, Tuple[Sequence[Optimizer], Sequence[Any]]]:
        """SGD (lr=0.1, momentum=0.9, weight_decay=1e-4) with cosine annealing.

        Returns:
            A dict with ``"optimizer"`` and ``"lr_scheduler"`` keys.

        Note:
            ``T_max=200`` is fixed here and is independent of the Trainer's ``max_epochs``.
        """
        sgd = torch.optim.SGD(self.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
        cosine = torch.optim.lr_scheduler.CosineAnnealingLR(sgd, T_max=200)
        return dict(optimizer=sgd, lr_scheduler=cosine)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.utils.data import DataLoader

from torch.autograd import Variable

# Public names of this cell if it is moved into a module. The nn.Module defined in this
# cell is `ResNetModule`; the Lightning wrapper `ResNet` lives in an earlier cell, and
# listing a name the module does not define makes `from module import *` raise AttributeError.
__all__ = ["ResNetModule", "resnet20", "resnet32", "resnet44", "resnet56", "resnet110", "resnet1202"]


def _weights_init(m):
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)


class LambdaLayer(nn.Module):
    """Adapt an arbitrary callable into a parameter-free ``nn.Module``.

    Args:
        lambd: callable applied to the input in ``forward``.
    """

    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        fn = self.lambd
        return fn(x)


class BasicBlock(nn.Module):
    """Residual block: conv3x3-BN-ReLU-conv3x3-BN plus a shortcut.

    Args:
        in_planes: number of input channels.
        planes: number of channels produced by both convolutions.
        stride: stride of the first convolution (and of the option "B" projection).
        option: shortcut used when the shape changes. "A" is parameter-free: it takes
            every second pixel (always by 2, regardless of ``stride``) and zero-pads
            ``planes // 4`` channels on each side. "B" is a 1x1 conv + BN projection. Any
            other value leaves an identity shortcut.

    NOTE(review): ``forward`` returns ``bn2(conv2(...)) + shortcut(x)`` with no ReLU after
    the addition, and ``ResNetModule.forward`` only applies ReLU between stages, so blocks
    inside one stage are chained without a ReLU in between. The usual CIFAR ResNet applies
    ReLU after the addition — confirm this is intended.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option="A"):
        super().__init__()

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # Built after the main branch so submodule order and parameter-init order match.
        self.shortcut = self._make_shortcut(in_planes, planes, stride, option)

    def _make_shortcut(self, in_planes, planes, stride, option):
        """Return the module whose output is added to the main branch."""
        if stride == 1 and in_planes == planes:
            return nn.Sequential()  # shapes already match: identity
        if option == "A":
            # For CIFAR10 the ResNet paper uses option A.
            half_pad = planes // 4
            return LambdaLayer(
                lambda t: F.pad(t[:, :, ::2, ::2], (0, 0, 0, 0, half_pad, half_pad), "constant", 0)
            )
        if option == "B":
            return nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes),
            )
        return nn.Sequential()

    def forward(self, x):
        branch = F.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        return branch + self.shortcut(x)


class ResNetModule(nn.Module):
    """CIFAR-style ResNet backbone: a 3x3 stem followed by three stages of residual blocks.

    Args:
        block: residual block class; it must define ``expansion`` and accept
            ``(in_planes, planes, stride)``.
        num_blocks: number of blocks in each of the three stages.
        num_classes: size of the classifier output.

    ``forward`` returns ``{"embeds": ..., "logits": ...}``. ``embeds`` is the global average
    pool of the last stage's output, with no ReLU applied. ``logits`` is the linear
    classifier applied to ``embeds``.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNetModule, self).__init__()
        self.in_planes = 16

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        # Width of the last stage's output (64 for BasicBlock, whose expansion is 1).
        self.linear = nn.Linear(64 * block.expansion, num_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses `stride`.

        Advances ``self.in_planes`` so the next stage is built for this stage's output width.
        """
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.layer1(out))
        out = F.relu(self.layer2(out))
        out = self.layer3(out)
        # Global average pool to 1x1; unlike a fixed kernel of size W this also handles
        # non-square feature maps.
        embeds = F.adaptive_avg_pool2d(out, 1).flatten(1)
        # The classifier consumes the un-rectified embeddings.
        logits = self.linear(embeds)
        return {"embeds": embeds, "logits": logits}


# Factories for the CIFAR ResNet depths 6n+2 (n BasicBlocks per stage). They build the
# plain nn.Module backbone `ResNetModule`. `ResNet` is the Lightning wrapper, whose
# signature is (num_classes, input_dim, model, ...), so calling it here raised TypeError.


def resnet20(num_classes=10):
    """ResNet-20 backbone: 3 BasicBlocks per stage."""
    return ResNetModule(BasicBlock, [3, 3, 3], num_classes=num_classes)


def resnet32(num_classes=10):
    """ResNet-32 backbone: 5 BasicBlocks per stage."""
    return ResNetModule(BasicBlock, [5, 5, 5], num_classes=num_classes)


def resnet44(num_classes=10):
    """ResNet-44 backbone: 7 BasicBlocks per stage."""
    return ResNetModule(BasicBlock, [7, 7, 7], num_classes=num_classes)


def resnet56(num_classes=10):
    """ResNet-56 backbone: 9 BasicBlocks per stage."""
    return ResNetModule(BasicBlock, [9, 9, 9], num_classes=num_classes)


def resnet110(num_classes=10):
    """ResNet-110 backbone: 18 BasicBlocks per stage."""
    return ResNetModule(BasicBlock, [18, 18, 18], num_classes=num_classes)


def resnet1202(num_classes=10):
    """ResNet-1202 backbone: 200 BasicBlocks per stage."""
    return ResNetModule(BasicBlock, [200, 200, 200], num_classes=num_classes)

In [None]:
def train_model(train_data, test_data, num_classes=100, batch_size=128, max_epochs=1):
    """Train a fresh ResNet-20-style classifier, test it, and collect test-set embeddings.

    Args:
        train_data: training split with "x"/"y" entries (as produced by the preprocessing
            cell); iterated with shuffling.
        test_data: evaluation split in the same format; used for ``trainer.test`` and for
            embedding extraction, without shuffling.
        num_classes: classifier output size.
        batch_size: batch size for both dataloaders.
        max_epochs: number of training epochs.

    Returns:
        ``(model, test_embeds)``: the trained module, moved to the GPU and back in its
        post-test train/eval mode, and a list with one numpy array of ``"embeds"`` per
        test batch, in ``test_data`` order.
    """
    model = ResNet(
        num_classes=num_classes,
        model=ResNetModule(BasicBlock, [3, 3, 3], num_classes=num_classes),
        input_dim=32,
        transform_func=None,
    )

    # Lightning calls `configure_optimizers` itself during `fit`, so it is not called here.
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

    # NOTE(review): `gpus=` is the pre-2.0 Lightning API; newer versions expect
    # `accelerator="gpu", devices=1`.
    trainer = pl.Trainer(
        gpus=1,
        max_epochs=max_epochs,
        precision=32,
    )

    trainer.fit(model, train_loader)

    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    trainer.test(model, test_loader)

    model = model.cuda()
    test_embeds = _collect_embeddings(model, test_loader)

    return model, test_embeds


def _collect_embeddings(model, dataloader):
    """Return one numpy array of ``model(x)["embeds"]`` per batch of ``dataloader``.

    Runs in eval mode under ``torch.no_grad()``, so BatchNorm uses its running statistics
    and no autograd graph is built. The model's previous train/eval mode is restored
    afterwards.
    """
    was_training = model.training
    model.eval()
    embeds = []
    try:
        with torch.no_grad():
            for batch in dataloader:
                batch_embeds = model(batch["x"].cuda())["embeds"]
                embeds.append(batch_embeds.cpu().numpy())
    finally:
        model.train(was_training)
    return embeds

In [None]:
# Task 1: train on its split and keep both the model and its per-batch test embeddings.
model1, test_embeds1 = train_model(train_data=data["task_1_train"], test_data=data["task_1_test"])

In [None]:
# Task 2: `train_model` returns (model, test_embeds); unpack it the same way as for task 1
# so `model2` is the trained module rather than a tuple.
model2, test_embeds2 = train_model(data["task_2_train"], data["task_2_test"])