In [21]:
!pip install git+https://github.com/openai/CLIP.git


Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-wqmglgbb
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-wqmglgbb
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [22]:
!git clone https://github.com/xk-huang/OrdinalCLIP.git

fatal: destination path 'OrdinalCLIP' already exists and is not an empty directory.


In [23]:
import sys
sys.path.append("/kaggle/working/OrdinalCLIP")


In [24]:
import os
import time
import copy
import math 
import numpy as np
import pandas as pd
from sklearn.metrics import f1_score, confusion_matrix
from PIL import Image
import matplotlib.pyplot as plt
import copy
import torch.nn.init as init
from IPython.display import FileLink
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torchvision._internally_replaced_utils import load_state_dict_from_url
from torch import optim
import torch.utils.model_zoo as model_zoo
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import lr_scheduler
from clip import clip
from clip.model import CLIP
import logging
from scipy.ndimage import gaussian_filter1d
from scipy.signal.windows import triang
import torchvision
from torchvision import transforms, datasets, models
from typing import Optional, Callable, List, Type, Union, Any
from ordinalclip.utils import Registry

import timm
from torchsummary import summary
from tqdm import tqdm
from torchvision.models import densenet121, densenet201
# `torch.accelerator` only exists in recent PyTorch releases; on older versions
# fall back to plain CUDA detection instead of raising AttributeError.
if hasattr(torch, "accelerator"):
    device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

In [25]:
# ---- Data locations (Kaggle "aptos2019" input dataset) ----
train_img_dir = "/kaggle/input/aptos2019/train_images/train_images"
val_img_dir = "/kaggle/input/aptos2019/val_images/val_images"
test_img_dir = "/kaggle/input/aptos2019/test_images/test_images"

train_csv_path = "/kaggle/input/aptos2019/train_1.csv"
val_csv_path = "/kaggle/input/aptos2019/valid.csv"
test_csv_path = "/kaggle/input/aptos2019/test.csv"

# Per-channel normalisation statistics used by CLIP's image preprocessing.
CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_STD = [0.26862954, 0.26130258, 0.27577711]

# "train" adds random flips/rotation; "test" (used for validation and test) is deterministic.
transform = dict(
    train=transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(degrees=30),
            transforms.ToTensor(),
            transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
        ]
    ),
    test=transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
        ]
    ),
)
class AptosMapper(Dataset):
    """Map-style dataset over one APTOS-2019 split described by a CSV file.

    Column 0 of the CSV holds the image id (file name without ``.png``) and
    column 1 the integer label.

    Args:
        csv_file: path to the split CSV.
        img_dir: directory containing the ``<id>.png`` images.
        transforms: optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, csv_file, img_dir, transforms=None):
        self.labels = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transforms

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for CSV row ``idx``."""
        # Format instead of `+`: if every id is numeric pandas parses the column
        # as integers, and `int + ".png"` would raise TypeError.
        # NOTE(review): such all-numeric ids also lose any leading zeros.
        img_name = f"{self.labels.iloc[idx, 0]}.png"
        label = int(self.labels.iloc[idx, 1])
        img_path = os.path.join(self.img_dir, img_name)
        # The context manager releases the file handle; convert() returns a new image.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label

batch_size = 64

# Training images get the augmenting pipeline; validation and test share the deterministic one.
train_dataset = AptosMapper(train_csv_path, train_img_dir, transforms=transform["train"])
val_dataset = AptosMapper(val_csv_path, val_img_dir, transforms=transform["test"])
test_dataset = AptosMapper(test_csv_path, test_img_dir, transforms=transform["test"])

len_trainset, len_valset, len_test = (len(split) for split in (train_dataset, val_dataset, test_dataset))

# Only the training loader shuffles; all three use two worker processes.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
print(len_trainset, len_valset, len_test)


In [26]:
# Stand-alone CLIP ResNet-50 together with the image `preprocess` transform returned by clip.load.
# NOTE(review): `model` is rebound to a CLIPDR instance further down the notebook.
model, preprocess = clip.load("RN50", device=device)
model = model.eval()  # nn.Module.eval() returns the module itself

print("CLIP loaded on:", device)

In [27]:
# Optional path to a text file with one rank name per line; when set, the rank
# embeddings start from CLIP's token embeddings of those names. None = random init.
init_rank_path = None


class PlainPromptLearner(nn.Module):
    """Learnable per-rank prompts laid out in CLIP's 77-token window.

    Row ``i`` of the output is::

        [SOT] <context / rank-i tokens> [.] [EOT] [padding ...]

    The context embeddings (shared, or one set per rank with
    ``rank_specific_context``) and the rank embeddings are trainable
    parameters; SOT, full stop, EOT and padding are fixed copies of CLIP's
    token embeddings held in the ``sentence_embeds`` buffer.

    Args:
        clip_model: CLIP model providing ``token_embedding``.
        num_ranks: number of ranks (one prompt each).
        num_tokens_per_rank: rank-token count, shared (int) or per rank (list).
            Replaced by the tokenizer's counts when ``init_rank_path`` is set.
        num_context_tokens: number of context tokens; replaced by the token
            count of ``init_context`` when that is given.
        rank_tokens_position: ``"tail"`` (context then rank), ``"front"`` (rank
            then context) or ``"middle"`` (rank inserted in the context).
        init_context: optional text used to initialise the context embeddings.
        rank_specific_context: learn a separate context for every rank.

    Raises:
        ValueError: unknown ``rank_tokens_position``, a prompt that does not fit
            in 77 tokens, or a rank file with the wrong number of lines.
    """

    clip_max_num_tokens = 77
    rank_tokens_position_candidates = {"tail", "middle", "front"}

    def __init__(
        self,
        clip_model: "CLIP",
        num_ranks: int,
        num_tokens_per_rank: Union[int, List],
        num_context_tokens: int,
        rank_tokens_position: str = "tail",
        init_context: Optional[str] = None,
        rank_specific_context: bool = False,
    ):
        super().__init__()
        # Validate up front instead of after all embeddings have been built.
        if rank_tokens_position not in self.rank_tokens_position_candidates:
            raise ValueError(f"Invalid rank_tokens_position: {rank_tokens_position}")

        dtype = clip_model.token_embedding.weight.dtype

        # init_context may change the number of context tokens.
        context_embeds, num_context_tokens = self.create_context_embeds(
            clip_model, num_ranks, num_context_tokens, init_context, rank_specific_context, dtype
        )
        self.context_embeds = nn.Parameter(context_embeds)

        if isinstance(num_tokens_per_rank, int):
            num_tokens_per_rank = [num_tokens_per_rank] * num_ranks
        # Uses the module-level `init_rank_path` setting; may change the token counts.
        rank_embeds, num_tokens_per_rank = self.create_rank_embeds(
            clip_model, num_ranks, num_tokens_per_rank, init_rank_path, dtype, num_context_tokens
        )
        self.rank_embeds = nn.Parameter(rank_embeds)
        assert len(rank_embeds) == num_ranks

        psudo_sentence_tokens = self.create_psudo_sentence_tokens(
            num_tokens_per_rank, num_context_tokens, num_ranks
        )
        self.register_buffer("psudo_sentence_tokens", psudo_sentence_tokens, persistent=False)

        self.num_ranks = num_ranks
        self.num_context_tokens = num_context_tokens
        self.num_tokens_per_rank = num_tokens_per_rank
        # (sic) misspelled attribute name kept for compatibility with existing callers.
        self.rank_tokens_positon = rank_tokens_position
        self.embeddings_dim = clip_model.token_embedding.embedding_dim

        self.create_sentence_embeds_template(clip_model, num_ranks, psudo_sentence_tokens)

    def forward(self):
        """Return prompt embeddings of shape (num_ranks, 77, embeddings_dim).

        Slots ``1 .. num_context_tokens + num_tokens_per_rank[i]`` of row ``i``
        receive the context and rank embeddings; every other slot comes from
        the fixed ``sentence_embeds`` template.
        """
        context_embeds = self.context_embeds
        if context_embeds.dim() == 2:
            # Shared context: view it once per rank (expand does not copy).
            context_embeds = context_embeds[None].expand(self.num_ranks, *context_embeds.shape)

        sentence_embeds = self.sentence_embeds.clone()
        half_range = self.num_context_tokens // 2
        for i in range(self.num_ranks):
            n_rank_tokens = self.num_tokens_per_rank[i]
            rank_part = self.rank_embeds[i, :n_rank_tokens]
            ctx = context_embeds[i]
            if self.rank_tokens_positon == "tail":
                pieces = [ctx, rank_part]
            elif self.rank_tokens_positon == "front":
                pieces = [rank_part, ctx]
            elif self.rank_tokens_positon == "middle":
                pieces = [ctx[:half_range], rank_part, ctx[half_range:]]
            else:
                continue  # unreachable after __init__ validation; row keeps the template
            sentence_embeds[i, 1:1 + self.num_context_tokens + n_rank_tokens] = torch.cat(pieces, dim=0)
        return sentence_embeds

    def create_sentence_embeds_template(self, clip_model, num_ranks, psudo_sentence_tokens):
        """Register the fixed ``sentence_embeds`` buffer, shape (num_ranks, 77, dim).

        Every slot starts as the embedding of token id 0. Slot 0 gets SOT (49406).
        In each row, the slot at the argmax of ``psudo_sentence_tokens`` gets EOT (49407),
        and the slot before it gets the full stop "." (269).
        """
        with torch.no_grad():
            weight = clip_model.token_embedding.weight
            special_ids = torch.tensor([0, 49406, 49407, 269], device=weight.device)
            null_embed, sot_embed, eot_embed, full_stop_embed = clip_model.token_embedding(
                special_ids
            ).to(weight.dtype)

        sentence_embeds = null_embed[None, None].repeat(num_ranks, self.clip_max_num_tokens, 1)
        eot_index = psudo_sentence_tokens.argmax(dim=-1).to(sentence_embeds.device)
        rank_index = torch.arange(num_ranks, device=sentence_embeds.device)

        sentence_embeds[:, 0, :] = sot_embed
        sentence_embeds[rank_index, eot_index] = eot_embed
        sentence_embeds[rank_index, eot_index - 1] = full_stop_embed

        self.register_buffer("sentence_embeds", sentence_embeds, persistent=False)

    def create_psudo_sentence_tokens(self, num_tokens_per_rank, num_context_tokens, num_ranks):
        """Return a (num_ranks, 77) long tensor; row i is ``0, 1, ..., L_i - 1`` then zeros.

        ``L_i = 1 (SOT) + num_context_tokens + rank tokens of rank i + 1 (".") + 1 (EOT)``,
        so the argmax of row i is the EOT position.
        """
        tokens = torch.zeros(num_ranks, self.clip_max_num_tokens, dtype=torch.long)
        if isinstance(num_tokens_per_rank, list):
            assert num_ranks == len(num_tokens_per_rank)
            lengths = [1 + num_context_tokens + n + 1 + 1 for n in num_tokens_per_rank]
        else:
            lengths = [1 + num_context_tokens + num_tokens_per_rank + 1 + 1] * num_ranks
        for i, length in enumerate(lengths):
            tokens[i, :length] = torch.arange(0, length, dtype=torch.long)
        return tokens

    def create_rank_embeds(
        self, clip_model, num_ranks, num_tokens_per_rank, init_rank_path, dtype, num_context_tokens
    ):
        """Create initial rank embeddings, shape (num_ranks, max tokens per rank, dim).

        With ``init_rank_path`` the rank names are tokenised and embedded with
        CLIP's token embedding, and the per-rank token counts come from the
        tokenizer. Otherwise the embeddings are drawn from N(0, 0.02^2).

        Returns:
            ``(rank_embeds, num_tokens_per_rank)``.

        Raises:
            ValueError: wrong number of rank names, or the prompt exceeds 77 tokens.
        """
        # The "3" below reserves room for SOT, the full stop and EOT.
        if init_rank_path is not None:
            rank_names = self.read_rank_file(init_rank_path)
            if len(rank_names) != num_ranks:
                raise ValueError("rank_names length mismatch")

            _rank_tokens = [clip._tokenizer.encode(rank_name) for rank_name in rank_names]
            num_tokens_per_rank = [len(rank_token) for rank_token in _rank_tokens]
            max_num_tokens_per_rank = max(num_tokens_per_rank)

            valid_length = self.clip_max_num_tokens - num_context_tokens - 3
            rank_tokens = torch.zeros(len(_rank_tokens), max_num_tokens_per_rank, dtype=torch.long)
            for i, rank_token in enumerate(_rank_tokens):
                if len(rank_token) > valid_length:
                    raise ValueError("rank tokens too long")
                rank_tokens[i, :len(rank_token)] = torch.LongTensor(rank_token)

            weight = clip_model.token_embedding.weight
            # Ids must be on the embedding's device (CLIP may already live on CUDA).
            with torch.no_grad():
                rank_embeds = clip_model.token_embedding(rank_tokens.to(weight.device)).type(dtype)
            rank_embeds = rank_embeds[:, :max_num_tokens_per_rank]
        else:
            embeddings_dim = clip_model.token_embedding.embedding_dim
            if isinstance(num_tokens_per_rank, list):
                max_num_tokens_per_rank = max(num_tokens_per_rank)
            else:
                max_num_tokens_per_rank = num_tokens_per_rank
            if self.clip_max_num_tokens < num_context_tokens + max_num_tokens_per_rank + 3:
                raise ValueError("rank tokens too long")
            rank_embeds = torch.empty((num_ranks, max_num_tokens_per_rank, embeddings_dim), dtype=dtype)
            nn.init.normal_(rank_embeds, std=0.02)

        return rank_embeds, num_tokens_per_rank

    def read_rank_file(self, init_rank_path):
        """Return one rank name per line: stripped, with '_' replaced by spaces."""
        with open(init_rank_path, "r", encoding="utf-8") as f:
            return [line.strip().replace("_", " ") for line in f]

    def create_context_embeds(
        self,
        clip_model,
        num_ranks: int,
        num_context_tokens: int,
        init_context: Optional[str],
        rank_specific_context: bool,
        dtype,
    ):
        """Create initial context embeddings.

        Returns ``(context_embeds, num_context_tokens)``. The embeddings have
        shape ``(num_context_tokens, dim)``, or ``(num_ranks, num_context_tokens, dim)``
        with ``rank_specific_context``. With ``init_context`` they are CLIP's
        embeddings of that text (``_`` read as space), and ``num_context_tokens``
        becomes its token count. Otherwise they are drawn from N(0, 0.02^2).
        """
        if init_context is not None:
            init_context = init_context.replace("_", " ")
            prompt_tokens = clip.tokenize(init_context)[0]
            # EOT has the largest id, so argmax locates it; the context is what
            # lies strictly between SOT (index 0) and EOT.
            num_context_tokens = torch.argmax(prompt_tokens).item() - 1

            weight = clip_model.token_embedding.weight
            # Ids must be on the embedding's device (CLIP may already live on CUDA).
            with torch.no_grad():
                context_embeds = clip_model.token_embedding(prompt_tokens.to(weight.device)).type(dtype)
            context_embeds = context_embeds[1:1 + num_context_tokens]

            if rank_specific_context:
                context_embeds = context_embeds[None].repeat(num_ranks, 1, 1)
        else:
            embeds_dim = clip_model.token_embedding.embedding_dim
            if rank_specific_context:
                context_embeds = torch.empty((num_ranks, num_context_tokens, embeds_dim), dtype=dtype)
            else:
                context_embeds = torch.empty((num_context_tokens, embeds_dim), dtype=dtype)
            nn.init.normal_(context_embeds, std=0.02)

        return context_embeds, num_context_tokens


In [28]:
def load_clip(
    text_encoder_name: str,
    image_encoder_name: str,
    device: str
):
    """Load a CLIP model in fp32, optionally swapping in a torchvision image encoder.

    Args:
        text_encoder_name: checkpoint name passed to ``clip.load`` (e.g. "RN50").
        image_encoder_name: if it differs from ``text_encoder_name``, the name of
            a model constructor in ``torchvision.models``. That model is built
            with no pretrained weights argument and with ``num_classes`` equal to
            CLIP's joint embedding size. It then replaces ``clip_model.visual``.
        device: device the whole model is placed on.

    Returns:
        The (possibly modified) CLIP model.

    Raises:
        ValueError: ``image_encoder_name`` is not a callable in ``torchvision.models``.
    """
    clip_model, _ = clip.load(text_encoder_name, device=device)
    clip_model = clip_model.float()

    if image_encoder_name != text_encoder_name:
        embed_dim = clip_model.text_projection.shape[1]
        input_resolution = clip_model.visual.input_resolution

        model_fn = getattr(models, image_encoder_name, None)
        # Submodules such as `models.resnet` are attributes too, but not constructors.
        if model_fn is None or not callable(model_fn):
            raise ValueError(f"Invalid torchvision model: {image_encoder_name}")

        # A freshly built torchvision model lives on CPU; move it next to the rest of CLIP.
        visual = model_fn(num_classes=embed_dim).to(device)
        visual.input_resolution = input_resolution
        clip_model.visual = visual

    return clip_model


In [29]:
class TextEncoder(nn.Module):
    """CLIP's text tower applied to already-embedded prompts.

    Shares (does not copy) the transformer, positional embedding, final
    LayerNorm and text projection of ``clip_model``. This lets prompts given as
    embeddings, rather than token ids, be encoded.
    """

    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection

    def forward(self, prompts, tokenized_prompts):
        """Encode ``prompts`` into one projected feature vector per row.

        ``prompts`` is indexed below as (batch, seq, width). ``tokenized_prompts``
        is only used through its per-row argmax, which picks the sequence
        position whose hidden state is projected.
        """
        hidden = prompts + self.positional_embedding
        # The transformer runs sequence-first: swap the first two axes in and back out.
        hidden = self.transformer(hidden.permute(1, 0, 2)).permute(1, 0, 2)
        hidden = self.ln_final(hidden)

        pooled_positions = tokenized_prompts.argmax(dim=-1)
        rows = torch.arange(hidden.shape[0])
        return hidden[rows, pooled_positions] @ self.text_projection


In [30]:
def forward_text_only(self):
    """Encode the learned prompts into text features, without normalising them.

    Free function written in method form: ``self`` must provide
    ``prompt_learner``, ``psudo_sentence_tokens`` and ``text_encoder``.
    """
    prompt_embeds = self.prompt_learner()
    return self.text_encoder(prompt_embeds, self.psudo_sentence_tokens)

def encode_image(self, x):
    """Run ``x`` through ``self.image_encoder`` and return its output unchanged."""
    image_encoder = self.image_encoder
    return image_encoder(x)


In [31]:
class CLIPDR(nn.Module):
    """CLIP image encoder scored against learnable rank prompts.

    Args:
        clip_model: CLIP model whose image tower, text transformer, positional
            embedding, final LayerNorm, text projection and ``logit_scale`` are
            reused (shared, not copied).
        prompt_learner: module whose ``forward()`` returns the prompt embeddings
            and which exposes ``num_ranks`` and ``psudo_sentence_tokens``.
    """

    def __init__(self, clip_model, prompt_learner):
        super().__init__()
        self.image_encoder = clip_model.visual
        # The text-tower pieces are also registered here (besides inside
        # text_encoder) so the existing state_dict keys stay unchanged.
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.logit_scale = clip_model.logit_scale
        self.prompt_learner = prompt_learner
        self.embed_dims = clip_model.text_projection.shape[1]
        self.num_ranks = self.prompt_learner.num_ranks
        self.text_encoder = TextEncoder(clip_model)

    @property
    def psudo_sentence_tokens(self):
        """Pseudo token ids of the prompts, read live from the prompt learner.

        A property rather than a reference captured in ``__init__``. ``.to()``
        and buffer reassignment replace the prompt learner's tensor, which would
        otherwise leave a stale copy here.
        """
        return self.prompt_learner.psudo_sentence_tokens

    def forward(self, images):
        """Score ``images`` against every prompt.

        Returns:
            ``(logits, image_features, text_features)``. ``logits`` has shape
            (batch, number of prompts) and equals ``exp(logit_scale)`` times the
            cosine similarity. Both feature tensors are L2-normalised (they are
            returned as well for FDS).
        """
        text_features = self.text_encoder(self.prompt_learner(), self.psudo_sentence_tokens)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        image_features = self.image_encoder(images)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        logits = self.logit_scale.exp() * image_features @ text_features.t()
        return logits, image_features, text_features

In [32]:
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")

# CLIP ResNet-50, cast to fp32.
clip_model, _ = clip.load("RN50", device=device)
clip_model = clip_model.float()
print("CLIP model loaded")

prompt_learner = PlainPromptLearner(
    clip_model=clip_model,
    num_ranks=5,                  # one learnable prompt per rank
    num_tokens_per_rank=1,
    num_context_tokens=10,
    rank_tokens_position="tail",  # context tokens first, then the rank token
    init_context=None,            # random context initialisation
    rank_specific_context=False,  # one context shared by all ranks
)
prompt_learner = prompt_learner.to(device)
print("Prompt learner created")

model = CLIPDR(clip_model=clip_model, prompt_learner=prompt_learner)
model = model.to(device)
print("CLIPDR model created")


In [33]:
# NOTE(review): this rebuilds the prompt learner and the CLIPDR wrapper with the
# same settings as the previous cell. The result is freshly (randomly)
# initialised prompts around the same shared `clip_model`; one of the two cells
# is probably redundant.
prompt_learner = PlainPromptLearner(
    clip_model,
    5,       # num_ranks
    1,       # num_tokens_per_rank
    10,      # num_context_tokens
    "tail",  # rank_tokens_position
    None,    # init_context
    False,   # rank_specific_context
).to(device)

model = CLIPDR(clip_model, prompt_learner).to(device)


In [34]:
def calibrate_mean_var(matrix, m1, v1, m2, v2, clip_min=0.5, clip_max=2.0):
    """Re-standardise features from statistics (m1, v1) to statistics (m2, v2).

    Computes ``(matrix - m1) * sqrt(clamp(v2 / v1, clip_min, clip_max)) + m2``
    column by column. If any entry has ``v1 <= 0`` or ``v2 < 0``, only the
    columns with ``v1 > 0`` and ``v2 >= 0`` are transformed; the others are
    left unchanged.

    Args:
        matrix: features, (num_samples, feature_dim).
        m1, v1: current per-dimension mean / variance, (feature_dim,).
        m2, v2: target per-dimension mean / variance, (feature_dim,).
        clip_min, clip_max: bounds for the variance ratio.

    Returns:
        The calibrated features as a new tensor (``matrix`` is not modified).
        If ``sum(v1) < 1e-10``, ``matrix`` itself is returned.
    """
    if torch.sum(v1) < 1e-10:
        return matrix

    if (v1 <= 0.).any() or (v2 < 0.).any():
        # Element-wise AND. Note `(a + b) == 2` does not work for bool tensors:
        # bool + bool stays bool, so that comparison is always False.
        valid_pos = (v1 > 0.) & (v2 >= 0.)
        factor = torch.clamp(v2[valid_pos] / v1[valid_pos], clip_min, clip_max)
        calibrated = matrix.clone()
        calibrated[:, valid_pos] = (
            matrix[:, valid_pos] - m1[valid_pos]
        ) * torch.sqrt(factor) + m2[valid_pos]
        return calibrated

    factor = torch.clamp(v2 / v1, clip_min, clip_max)
    return (matrix - m1) * torch.sqrt(factor) + m2


In [35]:
# FDS diagnostics go through a module logger. The builtin `print` is deliberately
# not rebound: `print = logging.info` replaced `print` for every later cell. With
# the root logger's default level, INFO messages are dropped, so those calls
# printed nothing, and multi-argument print calls became logging format errors.
logger = logging.getLogger(__name__)


class FDS(nn.Module):
    """Feature Distribution Smoothing over integer label buckets.

    Keeps a running per-bucket mean and variance of the features. At each new
    epoch it snapshots them and smooths the snapshot across neighbouring buckets
    with a 1-D kernel. ``smooth`` then re-calibrates every bucket's features
    from the snapshot statistics to the smoothed ones (``calibrate_mean_var``).

    Buckets are labels ``bucket_start .. bucket_num - 1`` (index
    ``label - bucket_start``). When an edge label occurs in a batch, its bucket
    also pools the samples with smaller (``bucket_start``) or larger
    (``bucket_num - 1``) labels. Other out-of-range labels are skipped.

    Args:
        feature_dim: feature size.
        bucket_num, bucket_start: bucket range, see above.
        start_update: initial value of ``self.epoch``. At this epoch the running
            statistics are overwritten (momentum 0) instead of blended.
        start_smooth: stored only; smoothing is not gated by epoch.
        kernel: ``'gaussian'``, ``'triang'`` or ``'laplace'``.
        ks, sigma: kernel size (expected odd) and width.
        momentum: weight of the old running statistics. ``None`` uses a
            cumulative average over all tracked samples.
    """

    def __init__(self, feature_dim, bucket_num=100, bucket_start=3, start_update=0, start_smooth=1,
                 kernel='gaussian', ks=5, sigma=2, momentum=0.9):
        super(FDS, self).__init__()
        self.feature_dim = feature_dim
        self.bucket_num = bucket_num
        self.bucket_start = bucket_start
        # Non-persistent buffer: follows the module across devices (instead of
        # being hard-wired to CUDA) and adds no state_dict key.
        self.register_buffer('kernel_window', self._get_kernel_window(kernel, ks, sigma), persistent=False)
        self.half_ks = (ks - 1) // 2
        self.momentum = momentum
        self.start_update = start_update
        self.start_smooth = start_smooth

        num_buckets = bucket_num - bucket_start
        self.register_buffer('epoch', torch.zeros(1).fill_(start_update))
        self.register_buffer('running_mean', torch.zeros(num_buckets, feature_dim))
        self.register_buffer('running_var', torch.ones(num_buckets, feature_dim))
        # Snapshot of the running statistics taken at the last epoch boundary ...
        self.register_buffer('running_mean_last_epoch', torch.zeros(num_buckets, feature_dim))
        self.register_buffer('running_var_last_epoch', torch.ones(num_buckets, feature_dim))
        # ... and its kernel-smoothed version (the calibration target).
        self.register_buffer('smoothed_mean_last_epoch', torch.zeros(num_buckets, feature_dim))
        self.register_buffer('smoothed_var_last_epoch', torch.ones(num_buckets, feature_dim))
        self.register_buffer('num_samples_tracked', torch.zeros(num_buckets))

    @staticmethod
    def _get_kernel_window(kernel, ks, sigma):
        """Return the smoothing kernel, normalised to sum to 1, as a CPU float32 tensor."""
        assert kernel in ['gaussian', 'triang', 'laplace']
        half_ks = (ks - 1) // 2
        if kernel == 'gaussian':
            base_kernel = np.array([0.] * half_ks + [1.] + [0.] * half_ks, dtype=np.float32)
            window = gaussian_filter1d(base_kernel, sigma=sigma)
        elif kernel == 'triang':
            window = triang(ks)
        else:
            offsets = np.arange(-half_ks, half_ks + 1)
            window = np.exp(-np.abs(offsets) / sigma) / (2. * sigma)
        window = window / window.sum()

        logger.info(f'Using FDS: [{kernel.upper()}] ({ks}/{sigma})')
        return torch.tensor(window, dtype=torch.float32)

    def _smooth_over_buckets(self, stats):
        """Convolve (num_buckets, feature_dim) stats with the kernel along the bucket axis.

        Reflect padding keeps the output the same length.
        """
        # (B, D) -> (D, 1, B): one single-channel signal over buckets per feature dim.
        signal = stats.unsqueeze(1).permute(2, 1, 0)
        signal = F.pad(signal, pad=(self.half_ks, self.half_ks), mode='reflect')
        smoothed = F.conv1d(input=signal, weight=self.kernel_window.view(1, 1, -1), padding=0)
        return smoothed.permute(2, 1, 0).squeeze(1)

    def _update_last_epoch_stats(self):
        # copy_ rather than rebinding: a rebound snapshot would alias the running
        # buffers, which are modified in place during the next epoch.
        self.running_mean_last_epoch.copy_(self.running_mean)
        self.running_var_last_epoch.copy_(self.running_var)
        self.smoothed_mean_last_epoch.copy_(self._smooth_over_buckets(self.running_mean_last_epoch))
        self.smoothed_var_last_epoch.copy_(self._smooth_over_buckets(self.running_var_last_epoch))

    def reset(self):
        """Reset all tracked statistics to their initial values (means 0, variances 1)."""
        self.running_mean.zero_()
        self.running_var.fill_(1)
        self.running_mean_last_epoch.zero_()
        self.running_var_last_epoch.fill_(1)
        self.smoothed_mean_last_epoch.zero_()
        self.smoothed_var_last_epoch.fill_(1)
        self.num_samples_tracked.zero_()

    def update_last_epoch_stats(self, epoch):
        """Snapshot and smooth the running statistics when ``epoch`` is the next epoch."""
        if epoch == self.epoch + 1:
            self.epoch += 1
            self._update_last_epoch_stats()
            logger.info(f"Updated smoothed statistics on Epoch [{epoch}]!")

    def update_running_stats(self, features, labels, epoch):
        """Fold a batch of ``features`` (N, feature_dim) with ``labels`` (N,) into the running stats."""
        assert self.feature_dim == features.size(1), "Input feature dimension is not aligned!"
        assert features.size(0) == labels.size(0), "Dimensions of features and labels are not aligned!"
        # Statistics are bookkeeping only; keep them out of the autograd graph.
        features = features.detach()

        for label in torch.unique(labels):
            if label > self.bucket_num - 1 or label < self.bucket_start:
                continue
            elif label == self.bucket_start:
                curr_feats = features[labels <= label]
            elif label == self.bucket_num - 1:
                curr_feats = features[labels >= label]
            else:
                curr_feats = features[labels == label]
            idx = int(label - self.bucket_start)
            curr_num_sample = curr_feats.size(0)
            curr_mean = torch.mean(curr_feats, 0)
            # A single sample has no unbiased variance; use the biased one (0).
            curr_var = torch.var(curr_feats, 0, unbiased=curr_num_sample != 1)

            self.num_samples_tracked[idx] += curr_num_sample
            if self.momentum is not None:
                factor = self.momentum
            else:
                factor = 1 - curr_num_sample / float(self.num_samples_tracked[idx])
            if epoch == self.start_update:
                factor = 0
            self.running_mean[idx] = (1 - factor) * curr_mean + factor * self.running_mean[idx]
            self.running_var[idx] = (1 - factor) * curr_var + factor * self.running_var[idx]

        logger.info(f"Updated running statistics with Epoch [{epoch}] features!")

    def smooth(self, features, labels, epoch):
        """Calibrate ``features`` in place, bucket by bucket, and return them.

        The bucket assignment matches ``update_running_stats``. ``epoch`` is
        unused: smoothing is applied whenever this is called.
        """
        for label in torch.unique(labels):
            if label > self.bucket_num - 1 or label < self.bucket_start:
                continue
            if label == self.bucket_start:
                mask = labels <= label
            elif label == self.bucket_num - 1:
                mask = labels >= label
            else:
                mask = labels == label
            idx = int(label - self.bucket_start)
            features[mask] = calibrate_mean_var(
                features[mask],
                self.running_mean_last_epoch[idx],
                self.running_var_last_epoch[idx],
                self.smoothed_mean_last_epoch[idx],
                self.smoothed_var_last_epoch[idx])
        return features

In [36]:
import math

import torch
from torch.optim.optimizer import Optimizer, required


class RAdam(Optimizer):
    """Rectified Adam (RAdam).

    Adam whose adaptive step is scaled by a variance-rectification factor.
    While the approximated SMA length ``N_sma`` is below 5, that factor is not
    used. The parameter is then either updated with bias-corrected momentum
    alone (``degenerated_to_sgd=True``) or left unchanged (the default).

    A non-zero ``weight_decay`` is applied directly to the weights
    (``p -= weight_decay * lr * p``) just before each actual update.

    ``(N_sma, step_size)`` depends only on the step count and the betas. It is
    therefore cached in a 10-slot ring buffer (``group["buffer"]``, slot
    ``step % 10``). Param groups passing betas that differ from the defaults get
    their own buffer.
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        degenerated_to_sgd=False,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")

        self.degenerated_to_sgd = degenerated_to_sgd

        def new_cache():
            # one [step, N_sma, step_size] slot per value of step % 10
            return [[None, None, None] for _ in range(10)]

        explicit_groups = isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict)
        if explicit_groups:
            for group in params:
                if "betas" not in group:
                    continue
                if group["betas"][0] != betas[0] or group["betas"][1] != betas[1]:
                    group["buffer"] = new_cache()

        super().__init__(
            params,
            dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=new_cache()),
        )

    def __setstate__(self, state):
        super().__setstate__(state)

    def _rectification(self, group, step, beta1, beta2):
        """Return ``(N_sma, step_size)`` for ``step``, reading or filling the group's cache.

        ``step_size`` excludes the learning rate and is -1 when no update should happen.
        """
        slot = group["buffer"][int(step % 10)]
        if slot[0] == step:
            return slot[1], slot[2]

        slot[0] = step
        beta2_t = beta2 ** step
        n_sma_max = 2 / (1 - beta2) - 1
        n_sma = n_sma_max - 2 * step * beta2_t / (1 - beta2_t)
        slot[1] = n_sma

        if n_sma >= 5:
            # rectification term, divided by Adam's first-moment bias correction
            step_size = math.sqrt(
                (1 - beta2_t)
                * (n_sma - 4) / (n_sma_max - 4)
                * (n_sma - 2) / n_sma
                * n_sma_max / (n_sma_max - 2)
            ) / (1 - beta1 ** step)
        elif self.degenerated_to_sgd:
            step_size = 1.0 / (1 - beta1 ** step)
        else:
            step_size = -1
        slot[2] = step_size
        return n_sma, step_size

    def step(self, closure=None):
        """Perform one optimisation step; returns ``closure()``'s loss when a closure is given."""
        loss = closure() if closure is not None else None

        for group in self.param_groups:
            for param in group["params"]:
                if param.grad is None:
                    continue
                grad = param.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError("RAdam does not support sparse gradients")

                # fp32 working copy (the parameter tensor itself when already fp32)
                weights = param.data.float()

                state = self.state[param]
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(weights)
                    state["exp_avg_sq"] = torch.zeros_like(weights)
                else:
                    state["exp_avg"] = state["exp_avg"].type_as(weights)
                    state["exp_avg_sq"] = state["exp_avg_sq"].type_as(weights)

                beta1, beta2 = group["betas"]
                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state["step"] += 1
                n_sma, step_size = self._rectification(group, state["step"], beta1, beta2)

                rectified = n_sma >= 5
                if not rectified and not step_size > 0:
                    continue  # warm-up without SGD fallback: leave the parameter as is

                lr = group["lr"]
                if group["weight_decay"] != 0:
                    weights.add_(weights, alpha=-group["weight_decay"] * lr)
                if rectified:
                    denom = exp_avg_sq.sqrt().add_(group["eps"])
                    weights.addcdiv_(exp_avg, denom, value=-step_size * lr)
                else:
                    weights.add_(exp_avg, alpha=-step_size * lr)
                param.data.copy_(weights)

        return loss


class PlainRAdam(Optimizer):
    """Rectified Adam that recomputes the rectification term at every step (no cache).

    For a parameter ``p`` at step ``t`` with Adam moments ``m`` and ``v``:

    * if the approximated SMA length ``N_sma >= 5``:
      ``p -= lr * r_t / (1 - beta1^t) * m / (sqrt(v) + eps)``, with
      ``r_t = sqrt((1 - beta2^t) * (N_sma-4)/(N_sma_max-4) * (N_sma-2)/N_sma * N_sma_max/(N_sma_max-2))``;
    * otherwise, with ``degenerated_to_sgd=True``: ``p -= lr / (1 - beta1^t) * m``;
    * otherwise the parameter is left unchanged.

    A non-zero ``weight_decay`` is applied directly to the weights
    (``p -= weight_decay * lr * p``) before each actual update.
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        degenerated_to_sgd=False,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")

        self.degenerated_to_sgd = degenerated_to_sgd
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)

        super(PlainRAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(PlainRAdam, self).__setstate__(state)

    def step(self, closure=None):
        """Perform one optimisation step; returns ``closure()``'s loss when a closure is given."""
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError("RAdam does not support sparse gradients")

                # fp32 working copy (the parameter tensor itself when already fp32)
                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p_data_fp32)
                    state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
                else:
                    state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
                    state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]

                # Keyword (`value=` / `alpha=`) forms: the positional
                # `(scalar, tensor)` overloads are deprecated in PyTorch.
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state["step"] += 1
                beta2_t = beta2 ** state["step"]
                N_sma_max = 2 / (1 - beta2) - 1
                N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t)

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    if group["weight_decay"] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"])
                    step_size = (
                        group["lr"]
                        * math.sqrt(
                            (1 - beta2_t)
                            * (N_sma - 4)
                            / (N_sma_max - 4)
                            * (N_sma - 2)
                            / N_sma
                            * N_sma_max
                            / (N_sma_max - 2)
                        )
                        / (1 - beta1 ** state["step"])
                    )
                    denom = exp_avg_sq.sqrt().add_(group["eps"])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                    p.data.copy_(p_data_fp32)
                elif self.degenerated_to_sgd:
                    if group["weight_decay"] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"])
                    step_size = group["lr"] / (1 - beta1 ** state["step"])
                    p_data_fp32.add_(exp_avg, alpha=-step_size)
                    p.data.copy_(p_data_fp32)

        return loss


class AdamW(Optimizer):
    """Adam with decoupled weight decay and an optional linear learning-rate warmup.

    The update is computed in fp32: each parameter is read with ``.float()`` and
    the result is copied back into ``p.data``, so non-fp32 parameters are supported.

    Args:
        params: iterable of parameters or param-group dicts.
        lr: peak learning rate (must be >= 0).
        betas: decay rates of the first/second moment running averages; each in [0, 1).
        eps: term added to the denominator for numerical stability (must be >= 0).
        weight_decay: decoupled weight decay; before the Adam update the parameter is
            shrunk by a factor ``1 - weight_decay * scheduled_lr``.
        warmup: number of steps over which the learning rate ramps linearly from
            about 1e-8 up to ``lr``; 0 disables warmup.

    Raises:
        ValueError: if ``lr``, ``eps`` or ``betas`` are out of range.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup=0):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")

        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, warmup=warmup)
        super(AdamW, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamW, self).__setstate__(state)

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # Follow the torch.optim convention: the closure needs autograd enabled
            # even if the caller wrapped step() in no_grad.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                if p.grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
                grad = p.grad.data.float()

                # For fp32 params this is p.data itself; otherwise it is an fp32 copy
                # that is written back at the end of the update.
                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p_data_fp32)
                    state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
                else:
                    # Keep moment buffers in the same dtype as the fp32 working copy.
                    state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
                    state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                state["step"] += 1
                step = state["step"]

                # Biased moment estimates. Keyword alpha/value replace the deprecated
                # positional-scalar overloads of add_/addcmul_/addcdiv_.
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                denom = exp_avg_sq.sqrt().add_(group["eps"])
                bias_correction1 = 1 - beta1 ** step
                bias_correction2 = 1 - beta2 ** step

                # Linear warmup: ramp from ~1e-8 to lr over the first `warmup` steps.
                if group["warmup"] > step:
                    scheduled_lr = 1e-8 + step * group["lr"] / group["warmup"]
                else:
                    scheduled_lr = group["lr"]

                step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1

                # Decoupled weight decay: shrink the weights directly rather than
                # adding the decay term to the gradient.
                if group["weight_decay"] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * scheduled_lr)

                p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)

                p.data.copy_(p_data_fp32)

        return loss

In [37]:
# Stand-alone RAdam optimizer over every model parameter: lr 1e-4, standard Adam
# betas, no weight decay.
# NOTE(review): Runner.configure_optimizers builds its own RAdam + MultiStepLR, so this
# module-level `optimizer` does not appear to be used by the Lightning training cells.
optimizer = RAdam(model.parameters(), lr=1e-4, betas=(0.9, 0.999), weight_decay=0)

In [38]:
# Multiply the learning rate by 0.1 once the scheduler has been stepped 60 times.
# NOTE(review): the Lightning Trainer uses the scheduler returned by
# Runner.configure_optimizers, not this module-level `scheduler`.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60], gamma=0.1)


In [44]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score, roc_curve, auc

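# Required names: `RAdam` and `FDS` are used below but are not imported in this cell.
# They must already be in scope (defined in an earlier cell or imported from your
# project), e.g.:
# from optim import RAdam
# from fds import FDS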


class Runner(pl.LightningModule):
    """LightningModule that trains and evaluates an ordinal (rank) classifier.

    ``model(x)`` must return a ``(logits, image_features, text_features)`` tuple;
    only the logits (one column per rank) are used for losses and metrics.

    Loss by stage (see ``run_step``):
      * train / test: cross-entropy + batch-wise KL term + auxiliary ``rank_loss``
      * validation:   cross-entropy + batch-wise KL term
    NOTE(review): test_loss therefore is not directly comparable to val_loss.

    Args:
        model: wrapped network. ``configure_optimizers`` also expects it to expose
            ``prompt_learner.context_embeds``, ``prompt_learner.rank_embeds`` and
            ``image_encoder``.
        num_ranks: number of ordinal classes; labels are ``0 .. num_ranks - 1``.
            NOTE(review): ``rank_loss`` and ``FDS(feature_dim=5)`` are still
            hard-coded for 5 ranks.
    """

    def __init__(self, model, num_ranks=5):
        super().__init__()
        self.model = model
        self.num_ranks = num_ranks

        self.ce_loss = nn.CrossEntropyLoss()
        self.kl_loss = nn.KLDivLoss(reduction="sum")

        # NOTE(review): FDS is constructed but never called anywhere in this class.
        # Kept unchanged, since removing it would presumably change the checkpoint
        # state_dict keys — confirm before deleting.
        self.FDS = FDS(
            feature_dim=5,
            bucket_num=100,
            bucket_start=3,
            start_update=0,
            start_smooth=1,
            kernel='gaussian',
            ks=5,
            sigma=2,
            momentum=0.9
        )

        # Rank values 0..num_ranks-1, used to turn class probabilities into an
        # expected rank; non-persistent, so it is not stored in checkpoints.
        self.register_buffer(
            "rank_output_value_array",
            torch.arange(0, num_ranks).float(),
            persistent=False
        )

    def forward(self, x):
        """Run the wrapped model and return its output unchanged."""
        return self.model(x)

    def forward_text_only(self):
        """Delegate to the wrapped model's ``forward_text_only``.

        The previous body called ``self.forward_text_only()`` and recursed until
        RecursionError. Presumably the wrapped (OrdinalCLIP-style) model provides
        this method — TODO confirm for the model used here.
        """
        return self.model.forward_text_only()

    def run_step(self, batch, batch_idx, M):
        """Shared forward pass, loss and metric computation for every stage.

        Args:
            batch: ``(x, y)`` pair; ``y`` holds integer rank labels.
            batch_idx: unused; kept to mirror the Lightning step signature.
            M: loss mode. ``0`` adds the auxiliary ``rank_loss`` to CE + KL; any
                other value uses CE + KL only.

        Returns:
            dict with ``"loss"`` plus the metrics from ``compute_per_example_metrics``
            for both the "exp" and "max" gather types.
        """
        x, y = batch

        our_logits, _image_features, _text_features = self.model(x)
        our_logits = our_logits.float()  # losses and metrics are computed in fp32

        loss_ce = self.ce_loss(our_logits, y)
        loss_kl = self.compute_kl_loss(our_logits, y)
        loss = loss_ce + loss_kl
        if M == 0:
            loss = loss + self.rank_loss(our_logits, y)

        metrics_exp = self.compute_per_example_metrics(our_logits, y, "exp")
        metrics_max = self.compute_per_example_metrics(our_logits, y, "max")

        return {"loss": loss, **metrics_exp, **metrics_max}

    def training_step(self, batch, batch_idx):
        """Training step: CE + KL + rank loss; logs per step and per epoch."""
        outputs = self.run_step(batch, batch_idx, M=0)
        self.logging(outputs, "train", on_step=True, on_epoch=True)
        return outputs

    def validation_step(self, batch, batch_idx):
        """Validation step: CE + KL only; logs epoch aggregates."""
        outputs = self.run_step(batch, batch_idx, M=1)
        self.logging(outputs, "val", on_step=False, on_epoch=True)
        return outputs

    def test_step(self, batch, batch_idx):
        """Test step: same loss as training (includes rank loss); logs epoch aggregates."""
        outputs = self.run_step(batch, batch_idx, M=0)
        self.logging(outputs, "test", on_step=False, on_epoch=True)
        return outputs

    def compute_kl_loss(self, logits, y):
        """Batch-wise KL term, averaged over the ranks present in the batch.

        For every rank that occurs in ``y``, the rank's logit column is
        softmax-normalised across the batch. It is then matched to a sharpened
        (x10) softmax of that rank's one-hot column, which concentrates mass on
        the samples carrying that label. Ranks absent from the batch get an
        all-zero target and contribute nothing.
        """
        y_t = F.one_hot(y, self.num_ranks).t()            # (num_ranks, batch)
        y_t_row_ind = y_t.sum(-1) > 0                      # ranks present in batch
        num_slots = y_t_row_ind.sum()
        y_t_reduction = (y_t * 10.0).softmax(-1)
        y_t_reduction[y_t_row_ind <= 0] = 0
        logits_t = logits.t()
        return self.kl_loss(F.log_softmax(logits_t, dim=-1), y_t_reduction) / num_slots

    def rank_loss(self, our_logits, y):
        """Auxiliary pairwise ranking loss (hard-coded for 5 ranks).

        For each sample with label ``c``, adjacent logit windows starting at ``c``,
        ``c+1`` and ``c+2`` are copied into three otherwise-zero logit matrices.
        Each matrix is then scored with soft-target cross-entropy.

        NOTE(review): ``images_similarity2`` and ``images_similarity3`` are never
        written, so their targets are all zero and ``loss2``/``loss3`` are always
        0. Only ``loss1`` contributes. The probable intent was targets at the
        lower rank of each window (``c+1`` and ``c+2``) — confirm before changing
        the objective. Also, non-window positions keep logit 0 (not -inf), so they
        still take part in the softmax.
        """
        indexA = torch.nonzero(y == 0, as_tuple=True)[0]
        indexB = torch.nonzero(y == 1, as_tuple=True)[0]
        indexC = torch.nonzero(y == 2, as_tuple=True)[0]
        indexD = torch.nonzero(y == 3, as_tuple=True)[0]
        indexF = torch.nonzero(y == 4, as_tuple=True)[0]

        images_similarity1 = torch.zeros(len(y), 5, device=our_logits.device)
        images_similarity2 = torch.zeros(len(y), 5, device=our_logits.device)
        images_similarity3 = torch.zeros(len(y), 5, device=our_logits.device)

        logits_similarity_image1 = torch.zeros_like(images_similarity1)
        logits_similarity_image2 = torch.zeros_like(images_similarity2)
        logits_similarity_image3 = torch.zeros_like(images_similarity3)

        for index in indexA:
            images_similarity1[index, 0] = 1
            logits_similarity_image1[index, :2] = our_logits[index, :2]
            logits_similarity_image2[index, 1:3] = our_logits[index, 1:3]
            logits_similarity_image3[index, 2:4] = our_logits[index, 2:4]

        for index in indexB:
            images_similarity1[index, 1] = 1
            logits_similarity_image1[index, 1:3] = our_logits[index, 1:3]
            logits_similarity_image2[index, 2:4] = our_logits[index, 2:4]
            logits_similarity_image3[index, 3:5] = our_logits[index, 3:5]

        for index in indexC:
            images_similarity1[index, 2] = 1
            logits_similarity_image1[index, 2:4] = our_logits[index, 2:4]
            logits_similarity_image2[index, 3:5] = our_logits[index, 3:5]

        for index in indexD:
            images_similarity1[index, 3] = 1
            logits_similarity_image1[index, 3:5] = our_logits[index, 3:5]

        for index in indexF:
            images_similarity1[index, 4] = 1
            logits_similarity_image1[index, 4] = our_logits[index, 4]

        # self.ce_loss is a default nn.CrossEntropyLoss (mean reduction), identical
        # to constructing a fresh one per call.
        loss1 = self.ce_loss(logits_similarity_image1, images_similarity1)
        loss2 = self.ce_loss(logits_similarity_image2, images_similarity2)
        loss3 = self.ce_loss(logits_similarity_image3, images_similarity3)

        return loss1 + loss2 + loss3

    def compute_per_example_metrics(self, logits, y, gather_type="exp"):
        """Compute MAE/accuracy per example and macro AUC/F1 per batch.

        Args:
            logits: (batch, num_ranks) class scores.
            y: integer labels.
            gather_type: "exp" uses the probability-weighted expected rank as the
                prediction; any other value uses the arg-max class.

        Returns:
            dict with ``mae_*``/``acc_*`` per-example tensors and ``*_DGDR_auc``/
            ``*_DGDR_f1`` scalar tensors, keyed by ``gather_type``.

        NOTE(review): AUC depends only on ``probs``, so it is identical for both
        gather types. AUC/F1 are per-batch values later averaged over batches,
        which is not the same as an epoch-level score. ``roc_auc_score`` may fail
        or return NaN for batches containing fewer than two distinct labels.
        """
        probs = F.softmax(logits, -1)
        dtype = logits.dtype

        if gather_type == "exp":
            predict_y = torch.sum(
                probs * self.rank_output_value_array.type(dtype),
                dim=-1
            )
        else:
            predict_y = torch.argmax(probs, dim=-1).type(dtype)

        mae = torch.abs(predict_y - y)
        acc = (torch.round(predict_y) == y).type(dtype)

        y_np = y.cpu().numpy()
        auc_ovo = roc_auc_score(
            y_np,
            probs.detach().cpu().numpy(),
            average='macro',
            multi_class='ovo',
            labels=list(range(self.num_ranks))
        )
        auc_ovo = torch.tensor(auc_ovo)

        f1 = f1_score(
            y_np,
            torch.round(predict_y).detach().cpu().numpy(),
            average='macro'
        )
        f1 = torch.tensor(f1)

        return {
            f"mae_{gather_type}_metric": mae,
            f"acc_{gather_type}_metric": acc,
            f"{gather_type}_DGDR_auc_metric": auc_ovo,
            f"{gather_type}_DGDR_f1_metric": f1
        }

    def logging(self, outputs, run_type, on_step=True, on_epoch=True):
        """Log the batch mean of every output whose key ends in "metric" or "loss"."""
        for k, v in outputs.items():
            if k.endswith("metric") or k.endswith("loss"):
                self.log(
                    f"{run_type}_{k}",
                    v.mean(),
                    on_step=on_step,
                    on_epoch=on_epoch,
                    prog_bar=True,
                    logger=True
                )

    def configure_optimizers(self):
        """RAdam over the prompt embeddings and image encoder, with lr x0.1 at epoch 60.

        The scheduler is stepped once per epoch (``interval="epoch"``).
        """
        params = [
            {"params": self.model.prompt_learner.context_embeds, "lr": 1e-4},
            {"params": self.model.prompt_learner.rank_embeds, "lr": 1e-4},
            {"params": self.model.image_encoder.parameters(), "lr": 1e-4},
        ]

        optimizer = RAdam(
            params,
            lr=1e-4,
            betas=(0.9, 0.999),
            weight_decay=0,
            degenerated_to_sgd=False
        )

        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[60],
            gamma=0.1
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1
            }
        }

In [None]:
# ============= TRAINING CODE =============
# Keep only the single best checkpoint (highest val_acc_exp_metric) plus last.ckpt.
checkpoint_callback = ModelCheckpoint(
    monitor='val_acc_exp_metric',  # Metric to monitor
    dirpath='checkpoints/',         # Where to save
    filename='best-model-{epoch:02d}-{val_acc_exp_metric:.2f}',
    save_top_k=1,                   # Save only the best model
    mode='max',                     # 'max' for accuracy, 'min' for loss
    save_last=True                  # Also save the last checkpoint
)

# Create runner and trainer.
# accelerator='auto' selects the GPU when one is available (as in the recorded run)
# and falls back to CPU otherwise, matching the `device` selection at the top of the
# notebook; a hard-coded 'gpu' fails on CPU-only kernels.
runner = Runner(model)
trainer = pl.Trainer(
    max_epochs=100,
    accelerator='auto',
    devices=1,
    callbacks=[checkpoint_callback]
)

# Train the model
trainer.fit(runner, train_loader, val_loader)

print("Training completed! Best model saved in checkpoints/ directory")
print(f"Best checkpoint: {checkpoint_callback.best_model_path}")


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

In [None]:
# ============= TESTING CODE =============
import glob
import os

# Find the best checkpoint. checkpoints/ may still hold best-model-*.ckpt files from
# earlier runs, and glob returns matches in arbitrary order, so pick the most recently
# written file (the latest run's best) instead of whichever happens to come first.
checkpoint_files = glob.glob('checkpoints/best-model-*.ckpt')
best_ckpt = max(checkpoint_files, key=os.path.getmtime) if checkpoint_files else 'checkpoints/last.ckpt'

print(f"\nLoading best model from: {best_ckpt}")

# Create a fresh runner; weights are restored from best_ckpt by Trainer.test.
test_runner = Runner(model)

# 'auto' uses the GPU when available and otherwise falls back to CPU.
test_trainer = pl.Trainer(
    accelerator='auto',
    devices=1
)

# Test the model
print("\nTesting the best model...")
test_results = test_trainer.test(test_runner, test_loader, ckpt_path=best_ckpt)

print("\nTest Results:")
print(test_results)