In [1]:
pip install wandb --upgrade

In [2]:
pip install datasets

In [3]:
import copy
import math
import multiprocessing
import random
import numpy as np
from tqdm.auto import tqdm
import wandb
import torch
import torchvision
from torch.utils.data import Dataset
from datasets import load_dataset
from torchvision import models
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10, STL10, MNIST
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast
import torch.nn.functional as F
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
# from torchinfo import summary
import os

import warnings

warnings.filterwarnings("ignore")

In [4]:
# MRL: Hugging Face dataset wrapper
class HuggingfaceDataset(Dataset):
    """Map-style torch ``Dataset`` adapter over an indexable collection of rows.

    Each row is expected to be a mapping that holds an image under
    ``image_key`` and a label under ``label_key``.

    Args:
        data: Indexable, sized collection of row mappings (e.g. the object
            returned by ``load_dataset(..., split=...)``).
        transform: Callable applied to the image of every returned sample.
        image_key: Key of the image inside a row.
        label_key: Key of the label inside a row.
    """

    def __init__(self, data, transform, image_key="image", label_key="label"):
        super().__init__()
        self.data = data
        self.transform = transform
        self.image_key = image_key
        self.label_key = label_key

    def __len__(self):
        """Return the number of rows in the wrapped collection."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(transform(image), int(label))`` for row ``idx``."""
        # Look the row up exactly once; indexing the backing dataset twice
        # (once for the image, once for the label) doubles the per-sample
        # fetch cost for no benefit.
        row = self.data[idx]
        image, label = row[self.image_key], int(row[self.label_key])

        image = self.transform(image)
        return image, label


In [5]:
# mrl customdataset

def tiny_imagenet(transform, split="train"):
    """Return the requested split of 'Maysee/tiny-imagenet' wrapped in HuggingfaceDataset."""
    return HuggingfaceDataset(
        load_dataset('Maysee/tiny-imagenet', split=split),
        transform,
    )


def food101(transform, split="train"):
    """Return the requested split of 'food101' wrapped in HuggingfaceDataset."""
    return HuggingfaceDataset(
        load_dataset('food101', split=split),
        transform,
    )


def imagenet1k(transform, split="train"):
    """Return the requested split of "imagenet-1k" wrapped in HuggingfaceDataset."""
    return HuggingfaceDataset(
        load_dataset("imagenet-1k", split=split),
        transform,
    )


In [6]:
# mrl encoders

def _adapt_resnet_model(model):
    """
    Modifies some layers to handle the smaller CIFAR images, following
    the SimCLR paper. Specifically, replaces the first conv layer with
    a smaller 3x3 kernel and 1x1 strides and removes the max pooling layer.
    """
    conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    nn.init.kaiming_normal_(conv1.weight, mode="fan_out", nonlinearity="relu")
    model.conv1 = conv1
    model.maxpool = nn.Identity()
    return model


class Squeeze(nn.Module):
    """Squeezes the last dimension twice; size-1 trailing dims are dropped."""

    def forward(self, x):
        for _ in range(2):
            x = x.squeeze(-1)
        return x


def _prep_encoder(model):
    """Build a feature extractor from ``model``.

    Keeps every direct child of ``model`` except the last one, then appends
    a global average pool to 1x1 and a ``Squeeze`` layer.
    """
    backbone = list(model.children())[:-1]
    return nn.Sequential(*backbone, nn.AdaptiveAvgPool2d(1), Squeeze())


def resnet18(modify_model=False):
    """Create an untrained ResNet-18 encoder; optionally adapt its stem first."""
    backbone = models.resnet18(weights=None)
    return _prep_encoder(
        _adapt_resnet_model(backbone) if modify_model else backbone)


def resnet50(modify_model=False):
    """Create an untrained ResNet-50 encoder; optionally adapt its stem first."""
    backbone = models.resnet50(weights=None)
    return _prep_encoder(
        _adapt_resnet_model(backbone) if modify_model else backbone)


In [7]:
# mrl aug

# Seed NumPy's global RNG when this cell runs.
# NOTE(review): the training cell seeds NumPy again with the same value (SEED = 42).
np.random.seed(42)

class ViewGenerator:
    """Callable that applies ``base_transform`` to one input ``n_views`` times.

    The transform is invoked independently for every view, and the results
    are returned as a list of length ``n_views``.
    """

    def __init__(self, base_transform, n_views=2):
        self.base_transform = base_transform
        self.n_views = n_views

    def __call__(self, x):
        views = []
        for _ in range(self.n_views):
            views.append(self.base_transform(x))
        return views


def _grayscale_to_rgb(img):
    if img.mode == "L" or img.mode != "RGB":
        return img.convert("RGB")
    return img


def _round_up_to_odd(num):
    return np.ceil(num) // 2 * 2 + 1


def get_training_transforms(image_size):
    """Compose the augmentation pipeline applied to each training view.

    Steps, in order: convert to RGB, random resized crop to ``image_size``
    (scale 0.08-1.0), random horizontal flip, colour jitter applied with
    probability 0.8, random grayscale (p=0.2), convert back to RGB, Gaussian
    blur with an odd kernel derived from 10% of ``image_size``, random
    solarize (threshold 127, p=0.5), and conversion to a tensor.

    ``image_size`` is multiplied by 0.1, so it must be a number.
    """
    blur_kernel = _round_up_to_odd(int(image_size * 0.1))
    jitter = transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)
    pipeline = [
        transforms.Lambda(_grayscale_to_rgb),
        transforms.RandomResizedCrop(image_size, scale=(0.08, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.Lambda(_grayscale_to_rgb),
        transforms.GaussianBlur(blur_kernel),
        transforms.RandomSolarize(127, 0.5),
        transforms.ToTensor(),
    ]
    return transforms.Compose(pipeline)


def get_inference_transforms(image_size=(96, 96)):
    """Compose the deterministic evaluation pipeline: resize, force RGB, to tensor."""
    steps = [
        transforms.Resize(image_size),
        transforms.Lambda(_grayscale_to_rgb),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)


In [8]:
# mrl utils

def get_dataset(args):
    """Build the two-view training dataset selected by ``args.dataset_name``.

    Args:
        args: Object exposing ``dataset_name`` and ``dataset_path``.
            ``dataset_path`` is used as the root directory for the
            torchvision datasets (cifar10, stl10, mnist); the Hugging Face
            backed datasets do not read it.

    Returns:
        A dataset whose samples are transformed by a ``ViewGenerator``
        producing two augmented views per image.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported names.
    """
    dataset_name, dataset_path = args.dataset_name, args.dataset_path

    def two_views(image_size):
        # Every dataset uses the same augmentation recipe; only the crop size differs.
        return ViewGenerator(get_training_transforms(image_size), 2)

    if dataset_name == "cifar10":
        return CIFAR10(dataset_path,
                       train=True,
                       download=True,
                       transform=two_views(32))
    if dataset_name == "stl10":
        return STL10(dataset_path,
                     split='unlabeled',
                     download=True,
                     transform=two_views(96))
    if dataset_name == "mnist":
        return MNIST(dataset_path,
                     train=True,
                     download=True,
                     transform=two_views(28))
    if dataset_name == "tiny_imagenet":
        return tiny_imagenet(transform=two_views(64))
    if dataset_name == "food101":
        return food101(transform=two_views(192))
    if dataset_name == "imagenet1k":
        return imagenet1k(transform=two_views(192))

    # The message lists every branch above, not just the first two.
    raise ValueError(
        f"Invalid dataset name {dataset_name!r} - options are "
        "[cifar10, stl10, mnist, tiny_imagenet, food101, imagenet1k]")

def get_encoder(model_name, modify_model=False):
    """Return the encoder built by ``resnet18`` or ``resnet50``.

    ``modify_model`` is forwarded to the selected builder. Any other
    ``model_name`` raises ``Exception``.
    """
    if model_name not in ("resnet18", "resnet50"):
        raise Exception(
            "Invalid model name - options are [resnet18, resnet50]")
    if model_name == "resnet18":
        return resnet18(modify_model)
    return resnet50(modify_model)


def accuracy(output, target, topk=(1, )):
    """Compute top-k accuracy (in percent) for each ``k`` in ``topk``.

    Args:
        output: Score tensor indexed as [sample, class].
        target: Tensor of true class indices, one per sample.
        topk: Iterable of ``k`` values.

    Returns:
        list: one single-element tensor per requested ``k``.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)

        # Indices of the ``largest_k`` best-scoring classes, laid out as [k, sample].
        top_idx = output.topk(largest_k, 1, True, True)[1].t()
        hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))

        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / n_samples)
            for k in topk
        ]


@torch.inference_mode()
def get_feature_size(encoder):
    """Get the feature size from the encoder using a dummy input.

    A single random 3x32x32 image is pushed through ``encoder`` in eval mode
    and the size of dimension 1 of the output is returned.

    The encoder's original train/eval mode is restored afterwards, so probing
    the feature size does not leave a model that is about to be trained stuck
    in eval mode. The dummy input is created on the device of the encoder's
    first parameter (CPU if it has none).
    """
    was_training = encoder.training
    encoder.eval()
    try:
        first_param = next(encoder.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
        dummy_input = torch.randn(1, 3, 32, 32, device=device)
        output = encoder(dummy_input)
        return output.shape[1]
    finally:
        encoder.train(was_training)


In [9]:

def logistic_regression(embeddings, labels, embeddings_val, labels_val):
    """Fit a calibrated logistic-regression probe and report validation accuracy.

    Args:
        embeddings: Training features.
        labels: Training labels.
        embeddings_val: Validation features.
        labels_val: Validation labels.

    Returns:
        The validation accuracy computed by ``accuracy_score``. It is also
        printed, as before; returning it lets callers log or compare it.
    """
    clf = CalibratedClassifierCV(LogisticRegression(max_iter=100))
    clf.fit(embeddings, labels)

    predictions = clf.predict(embeddings_val)

    acc = accuracy_score(labels_val, predictions)
    print("Accuracy STL10: ", acc)
    return acc


class STL10Eval:
    """Linear-probe evaluation of an encoder on the labelled STL10 splits.

    The constructor builds data loaders for the STL10 'train' and 'test'
    splits (downloaded into "data/" if needed). ``evaluate`` embeds both
    splits with the model's target encoder and fits ``logistic_regression``
    on the embeddings.
    """

    def __init__(self, image_size=96):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        transform = get_inference_transforms(image_size=(image_size, image_size))
        train_ds = torchvision.datasets.STL10("data/",
                                              split='train',
                                              transform=transform,
                                              download=True)
        val_ds = torchvision.datasets.STL10("data/",
                                            split='test',
                                            transform=transform,
                                            download=True)

        self.train_loader = DataLoader(train_ds,
                                       batch_size=64,
                                       num_workers=4)
        self.val_loader = DataLoader(val_ds,
                                     batch_size=64,
                                     num_workers=4)

    @torch.inference_mode()
    def evaluate(self, relic_model):
        """Embed STL10 with ``relic_model.target_encoder[0]`` and fit the probe.

        The model is switched to eval mode for the duration of the call and
        its previous train/eval mode is restored afterwards, even on error.
        Without the restore, a model evaluated mid-training would silently
        keep training in eval mode.
        """
        was_training = relic_model.training
        relic_model.eval()
        try:
            model = relic_model.target_encoder[0]
            embeddings, labels = self._get_image_embs_labels(model, self.train_loader)
            embeddings_val, labels_val = self._get_image_embs_labels(model, self.val_loader)

            logistic_regression(embeddings, labels, embeddings_val, labels_val)
        finally:
            relic_model.train(was_training)

    def _get_image_embs_labels(self, model, dataloader):
        """Run ``model`` over ``dataloader``; return (embeddings, labels) as NumPy arrays."""
        embs, labels = [], []
        # Gradients are never needed here; keep the guard so the helper is
        # safe to call outside of ``evaluate`` as well.
        with torch.no_grad():
            for images, targets in dataloader:
                out = model(images.to(self.device))
                embs.extend(out.cpu().detach().tolist())
                labels.extend(targets.cpu().detach().tolist())
        return np.array(embs), np.array(labels)


In [10]:
# mrl

class MatryoshkaProjector(nn.Module):
    """Two-layer projector evaluated at several nested input widths.

    For every ``d`` in ``nesting_dims`` the projector uses only the first
    ``d`` input features (and the matching first ``d`` weight columns) of
    each linear layer, and returns one output per nesting level.

    Args:
        nesting_dims: Non-empty sequence of widths; the last entry sizes
            both linear layers.
        out_dims: Output dimension of the final linear layer.
        **kwargs: Forwarded to both ``nn.Linear`` layers (e.g. ``bias``).

    Raises:
        ValueError: if ``nesting_dims`` is empty.
    """

    def __init__(self, nesting_dims, out_dims, **kwargs):
        super().__init__()
        if not nesting_dims:
            raise ValueError("Error in MatryoshkaProjector: 'nesting_dims' is empty. Please ensure 'nesting_dims' is populated correctly.")

        self.nesting_dims = nesting_dims
        self.proj_hidden = nn.Linear(nesting_dims[-1], nesting_dims[-1], **kwargs)
        self.relu = nn.ReLU()
        self.proj_linear = nn.Linear(nesting_dims[-1], out_dims, **kwargs)

    def _apply_linear_layer(self, x, layer, nesting_dim):
        """Apply ``layer`` using only the first ``nesting_dim`` input features."""
        logits = torch.matmul(x[:, :nesting_dim], layer.weight[:, :nesting_dim].t())
        if layer.bias is not None:
            logits = logits + layer.bias
        return logits

    def forward(self, x):
        """Return a list with one projection per entry of ``nesting_dims``."""
        nesting_logits = []
        for nesting_dim in self.nesting_dims:
            hidden = self.relu(
                self._apply_linear_layer(x, self.proj_hidden, nesting_dim))
            # Feed the hidden activation (not the raw input) into the output
            # layer. Previously the hidden result was overwritten, so
            # ``proj_hidden`` and the ReLU had no effect and got no gradient.
            logits = self._apply_linear_layer(hidden, self.proj_linear, nesting_dim)
            nesting_logits.append(logits)
        return nesting_logits


def relic_loss(x, x_prime, temp, alpha, max_tau=5.0):
    """
    Contrastive loss plus an ``alpha``-weighted KL invariance term.

    Parameters:
    x (torch.Tensor): Online projections [n, dim].
    x_prime (torch.Tensor): Target projections of shape [n, dim].
    temp (torch.Tensor): Learnable temperature parameter; it is exponentiated
        and clamped to [0, max_tau] before scaling the similarity logits.
    alpha (float): KL divergence (regularization term) weight.
    max_tau (float): Upper clamp for the exponentiated temperature.

    Returns:
    torch.Tensor: cross-entropy over the similarity matrix (the i-th row's
        target is column i) plus ``alpha`` times the KL term.
    """
    batch = x.size(0)
    online = F.normalize(x, p=2, dim=-1)
    target = F.normalize(x_prime, p=2, dim=-1)
    scale = temp.exp().clamp(0, max_tau)
    logits = torch.mm(online, target.t()) * scale

    # Instance discrimination: sample i should match target i.
    positives = torch.arange(batch).to(logits.device)
    contrastive = F.cross_entropy(logits, positives)

    # KL between the row-wise log-softmax and the transposed column-wise softmax.
    row_log_probs = F.log_softmax(logits, dim=1)
    col_probs = F.softmax(logits, dim=0).t()
    invariance = F.kl_div(row_log_probs, col_probs, reduction="batchmean")

    return contrastive + alpha * invariance


class ReLIC(torch.nn.Module):
    """Online/target encoder pair with a Matryoshka projection head.

    ``online_encoder`` is ``Sequential(encoder, MatryoshkaProjector)``;
    ``target_encoder`` is a deep copy of it with gradients disabled. The
    target is synchronised with the online network via ``copy_params``
    (hard copy) or ``update_params`` (interpolation). ``t_prime`` is a
    learnable one-element parameter initialised to zero.

    Args:
        encoder: Backbone module placed in front of the projector.
        proj_out_dim: Output dimension of the projector.
        nesting_dims: Nested input widths for the projector; derived from
            ``proj_in_dim`` when empty or ``None``.
        proj_in_dim: Projector input width; a falsy value triggers probing
            the encoder with ``get_feature_size``.
        matryoshka_bias: Passed to the projector as ``bias``.

    NOTE(review): the default ``proj_in_dim=1`` is truthy, so the encoder is
    never probed by default. With 1, ``range(3, int(log2(1)) + 1)`` is empty
    and ``nesting_dims`` falls back to [8, 16, 32, 64]; the projector then
    reads only the first 64 encoder features. Confirm this is intended
    before changing the default, since it determines the layer shapes.
    """

    def __init__(self,
                 encoder,
                 proj_out_dim=64,
                 nesting_dims=None,
                 proj_in_dim=1,
                 matryoshka_bias=False):
        super(ReLIC, self).__init__()

        # Probe the encoder only when proj_in_dim is falsy (None or 0).
        if not proj_in_dim:
            proj_in_dim = get_feature_size(encoder)
            print(f"Debug: Calculated proj_in_dim = {proj_in_dim}")
            if proj_in_dim <= 0:
                print("Warning: proj_in_dim is invalid (<= 0). Setting default proj_in_dim = 64.")
                proj_in_dim = 64
            print(f"Calculated proj_in_dim: {proj_in_dim}")

        # Guards explicitly passed negative values, which skip the branch above.
        if proj_in_dim <= 0:
            print("Warning: proj_in_dim is zero or negative. Setting default proj_in_dim=64 for testing.")
            proj_in_dim = 64

        # Populate nesting_dims if not provided or if empty:
        # powers of two from 8 up to the largest power of two <= proj_in_dim.
        if not nesting_dims:
            nesting_dims = [2 ** i for i in range(3, int(math.log2(proj_in_dim)) + 1)]
            if not nesting_dims:
                # Reached whenever proj_in_dim < 8 (including the default of 1).
                print("Warning: 'nesting_dims' is empty after calculation. Setting default nesting_dims = [8, 16, 32, 64].")
                nesting_dims = [8, 16, 32, 64]
            # NOTE(review): unreachable - nesting_dims was just given a non-empty default.
            if not nesting_dims:
                raise ValueError("Error in ReLIC: Calculated 'nesting_dims' is empty, check 'proj_in_dim'.")

        print(f"Final nesting_dims: {nesting_dims}")  # Debug print to verify nesting_dims

        # Initialize MatryoshkaProjector with verified nesting_dims
        proj = MatryoshkaProjector(nesting_dims, proj_out_dim, bias=matryoshka_bias)

        self.online_encoder = torch.nn.Sequential(encoder, proj)
        # The target starts as an exact copy and is never updated by the optimizer.
        self.target_encoder = copy.deepcopy(self.online_encoder)
        self.target_encoder.requires_grad_(False)
        self.t_prime = nn.Parameter(torch.zeros(1))


    @torch.inference_mode()
    def get_features(self, img):
        """Return the target backbone's output for ``img`` (projector not applied)."""
        with torch.no_grad():
            return self.target_encoder[0](img)

    def forward(self, x1, x2):
        """Project both views with the online and the target network.

        Returns:
            tuple: ``(o1, o2, t1, t2)`` - online outputs for ``x1``/``x2``
            and detached target outputs for ``x1``/``x2``. The target
            outputs are lists of tensors.
        """
        o1, o2 = self.online_encoder(x1), self.online_encoder(x2)
        with torch.no_grad():
            t1, t2 = self.target_encoder(x1), self.target_encoder(x2)
        t1 = [t_.detach() for t_ in t1]
        t2 = [t_.detach() for t_ in t2]
        return o1, o2, t1, t2

    @torch.inference_mode()
    def get_target_pred(self, x):
        """Return the target network's outputs for ``x`` as a list of detached tensors."""
        with torch.no_grad():
            t = self.target_encoder(x)
        t = [t_.detach() for t_ in t]
        return t

    def get_online_pred(self, x):
        """Return the online network's outputs for ``x`` (gradients enabled)."""
        return self.online_encoder(x)

    def update_params(self, gamma):
        """Move target parameters and buffers towards the online ones.

        Each target tensor is interpolated in place towards its online
        counterpart with weight ``1 - gamma``, so ``gamma`` is the fraction
        of the old target value that is kept. Only float32/float16 tensors
        are blended; tensors of any other dtype are left untouched.
        """
        with torch.no_grad():
            valid_types = [torch.float, torch.float16]
            for o_param, t_param in self._get_params():
                if o_param.dtype in valid_types and t_param.dtype in valid_types:
                    t_param.data.lerp_(o_param.data, 1. - gamma)

            for o_buffer, t_buffer in self._get_buffers():
                if o_buffer.dtype in valid_types and t_buffer.dtype in valid_types:
                    t_buffer.data.lerp_(o_buffer.data, 1. - gamma)

    def copy_params(self):
        """Overwrite every target parameter and buffer with the online value."""
        for o_param, t_param in self._get_params():
            t_param.data.copy_(o_param)

        for o_buffer, t_buffer in self._get_buffers():
            t_buffer.data.copy_(o_buffer)

    def save_encoder(self, path):
        """Save the state dict of the target backbone (without projector) to ``path``."""
        torch.save(self.target_encoder[0].state_dict(), path)

    def _get_params(self):
        """Pair online and target parameters in iteration order."""
        return zip(self.online_encoder.parameters(),
                   self.target_encoder.parameters())

    def _get_buffers(self):
        """Pair online and target buffers in iteration order."""
        return zip(self.online_encoder.buffers(),
                   self.target_encoder.buffers())


In [11]:
# pip install torchinfo

In [12]:
# train

# Seed shared by every RNG initialised below.
SEED = 42
# Upper clamp for the exponentiated temperature; passed to relic_loss as
# max_tau and reused when train_relic computes logits for logging.
MAX_TAU = 5.0

# Seed Python, NumPy and PyTorch (CPU and current CUDA device) RNGs.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)


# cosine EMA schedule (increase from tau_base to one) as defined in https://arxiv.org/abs/2010.07922
# k -> current training step, K -> maximum number of training steps
def update_gamma(k, K, tau_base):
    """Cosine schedule value at step ``k`` of ``K``.

    Evaluates ``1 - (1 - tau_base) * (cos(pi * k / K) + 1) / 2`` in float32,
    which equals ``tau_base`` at ``k == 0`` and 1 at ``k == K``.

    Returns:
        float: the scheduled value as a Python float.
    """
    step = torch.tensor(k, dtype=torch.float32)
    horizon = torch.tensor(K, dtype=torch.float32)

    cosine_term = torch.cos(torch.pi * step / horizon) + 1
    scheduled = 1 - (1 - tau_base) * cosine_term / 2
    return scheduled.item()


def _nested_relic_loss(online_outs, target_outs, temp, alpha):
    """Sum ``relic_loss`` over matching nesting levels of online/target outputs."""
    per_level = [
        relic_loss(o, t, temp, alpha, max_tau=MAX_TAU)
        for o, t in zip(online_outs, target_outs)
    ]
    return torch.stack(per_level).sum()


def train_relic(args):
    """Train a ReLIC model with a Matryoshka projector.

    Args:
        args: Configuration object. Reads ``dataset_name``, ``dataset_path``,
            ``encoder_model_name``, ``proj_out_dim``, ``ckpt_path``,
            ``learning_rate``, ``weight_decay``, ``batch_size``,
            ``num_epochs``, ``fp16_precision``, ``gamma``, ``alpha``,
            ``update_gamma_after_step``, ``update_gamma_every_n_steps`` and
            ``log_every_n_steps``. ``num_workers`` is optional and defaults
            to 0 when the attribute is missing.

    Side effects:
        Logs "loss" every step and "accuracy" every ``log_every_n_steps``
        steps to wandb, writes "relic_model.pth" and "encoder.pth" to the
        working directory, and runs the STL10 probe every
        ``5 * log_every_n_steps`` steps.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    modify_model = "cifar" in args.dataset_name
    encoder = get_encoder(args.encoder_model_name, modify_model)
    relic_model = ReLIC(encoder,
                        proj_out_dim=args.proj_out_dim)

    if args.ckpt_path:
        # map_location lets a checkpoint saved on GPU be resumed on a CPU-only host.
        model_state = torch.load(args.ckpt_path, map_location=device)
        relic_model.load_state_dict(model_state)
    relic_model = relic_model.to(device)
    relic_model.train()

    params = list(relic_model.online_encoder.parameters()) + [relic_model.t_prime]
    optimizer = torch.optim.Adam(params,
                                 lr=args.learning_rate,
                                 weight_decay=args.weight_decay)

    ds = get_dataset(args)
    # The attribute is only guaranteed when the sweep wrapper sets it.
    num_workers = max(0, getattr(args, "num_workers", 0))
    train_loader = DataLoader(ds,
                              batch_size=args.batch_size,
                              num_workers=num_workers,
                              drop_last=True,
                              pin_memory=True,
                              shuffle=True)

    scaler = GradScaler(enabled=args.fp16_precision)

    stl10_eval = STL10Eval()
    # NOTE(review): the schedule horizon is padded by two extra epochs; this
    # looks deliberate (gamma stays below 1 during training) - confirm.
    total_num_steps = (len(train_loader) *
                       (args.num_epochs + 2)) - args.update_gamma_after_step
    gamma = args.gamma
    global_step = 0
    total_loss = 0.0
    for epoch in range(args.num_epochs):
        epoch_loss = 0.0
        progress_bar = tqdm(train_loader,
                            desc=f"Epoch {epoch+1}/{args.num_epochs}")

        for step, (images, _) in enumerate(progress_bar):
            x1, x2 = images
            x1 = x1.to(device)
            x2 = x2.to(device)

            with autocast(enabled=args.fp16_precision):
                o1, o2, t1, t2 = relic_model(x1, x2)
                # Symmetric loss: each online view is matched against the
                # target projection of the other view.
                loss1 = _nested_relic_loss(o1, t2, relic_model.t_prime, args.alpha)
                loss2 = _nested_relic_loss(o2, t1, relic_model.t_prime, args.alpha)
                loss = (loss1 + loss2) / 2

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            if global_step > args.update_gamma_after_step and global_step % args.update_gamma_every_n_steps == 0:
                relic_model.update_params(gamma)
                gamma = update_gamma(global_step, total_num_steps, args.gamma)

            # During warm-up the target simply mirrors the online network.
            if global_step <= args.update_gamma_after_step:
                relic_model.copy_params()

            loss_value = loss.item()  # one host sync per step
            total_loss += loss_value
            epoch_loss += loss_value
            avg_loss = total_loss / (global_step + 1)
            ep_loss = epoch_loss / (step + 1)

            current_lr = optimizer.param_groups[0]['lr']
            progress_bar.set_description(
                f"Epoch {epoch+1}/{args.num_epochs} | "
                f"Step {global_step+1} | "
                f"Epoch Loss: {ep_loss:.4f} |"
                f"Total Loss: {avg_loss:.4f} |"
                f"Gamma: {gamma:.6f} |"
                f"Alpha: {args.alpha:.3f} |"
                f"Temp: {relic_model.t_prime.exp().item():.3f} |"
                f"Lr: {current_lr:.6f}")

            global_step += 1
            if global_step % args.log_every_n_steps == 0:
                with torch.no_grad():
                    x, x_prime = o1[-1], t2[-1]
                    x, x_prime = F.normalize(x, p=2, dim=-1), F.normalize(x_prime, p=2, dim=-1)
                    logits = torch.mm(x, x_prime.t()) * relic_model.t_prime.exp().clamp(0, MAX_TAU)
                labels = torch.arange(logits.size(0)).to(logits.device)
                top1, top5 = accuracy(logits, labels, topk=(1, 5))
                accuracy1 = top1[0].item()
                print("#" * 100)
                print('acc/top1 logits1', top1[0].item())
                print('acc/top5 logits1', top5[0].item())
                print("#" * 100)
                wandb.log({"accuracy": accuracy1})

                torch.save(relic_model.state_dict(),
                           "relic_model.pth")
                relic_model.save_encoder("encoder.pth")

            if global_step % (args.log_every_n_steps * 5) == 0:
                stl10_eval.evaluate(relic_model)
                # Evaluation switches the model to eval mode; switch back so
                # the remaining steps keep training-mode behaviour.
                relic_model.train()
                print("!" * 100)

            wandb.log({"loss": loss_value})


In [None]:
# # run_training

class Args:
    """Default training configuration consumed by ``train_relic``.

    NOTE(review): ``save_model_dir`` and ``proj_hidden_dim`` are not read
    anywhere in the visible notebook; presumably kept for compatibility.
    """

    def __init__(self):
        self.dataset_path = './data'
        # choices=['stl10', 'cifar10', "tiny_imagenet", "food101", "imagenet1k"])
        self.dataset_name = 'cifar10'
        # choices=['resnet18', 'resnet50']
        self.encoder_model_name = 'resnet18'
        self.save_model_dir = './models'
        self.num_epochs = 2
        self.batch_size = 64
        self.learning_rate = 1e-2
        self.weight_decay = 1e-4
        self.fp16_precision = False  # Assuming default is not using 16-bit precision
        self.proj_out_dim = 64
        self.proj_hidden_dim = 512
        self.log_every_n_steps = 400
        self.gamma = 0.995
        self.alpha = 0.5
        self.update_gamma_after_step = 1
        self.update_gamma_every_n_steps = 1
        self.ckpt_path = None
        # train_relic reads this when building the DataLoader; it used to be
        # set only by the sweep wrapper, so a direct train_relic(Args()) call
        # failed with AttributeError. 0 loads data in the main process.
        self.num_workers = 0


# Single shared configuration object; wrapper_train overwrites the swept
# fields on it before every run.
args = Args()

# W&B sweep definition: random search that minimises the "loss" metric,
# which is the key train_relic logs on every step.
sweep_config = {
    'method': 'random',
    'name': 'MRLsweep',  # Name of the sweep
    'metric': {
      'name': 'loss',
      'goal': 'minimize'
    },
    # Every key below is copied onto `args` by wrapper_train.
    'parameters': {
        'dataset_name': {
            'values': ['cifar10']
            # 'values': ['stl10', 'cifar10', 'tiny_imagenet', 'food101', 'imagenet1k', 'mnist']
        },
        'encoder_model_name': {
            'values': ['resnet18', 'resnet50']
        },
        'batch_size': {
            'values': [32, 64, 128]
        },
        'learning_rate': {
            'values': [1e-2, 1e-3, 1e-4]
        },
        'weight_decay': {
            'values': [1e-1, 1e-2, 1e-3]
        },
        'alpha': {
            'values': [0.1, 0.5, 1.0]
        },
        'num_workers': {
            'values': [0, 2, 4, 8]
        }
    }
}
# Tell wandb which notebook file this session belongs to.
os.environ["WANDB_NOTEBOOK_NAME"] = "NNDL_Mini_Project_MRL.ipynb"
# Authenticate, then register the sweep; sweep_id is consumed by main().
wandb.login()
sweep_id = wandb.sweep(sweep=sweep_config, project="NNDL_Mini_Project")

def wrapper_train():
    """Run one sweep trial.

    Opens a wandb run, copies the swept hyperparameters from the run's
    config onto the shared module-level ``args`` object, and trains with it.
    """
    swept_keys = (
        "dataset_name",
        "encoder_model_name",
        "batch_size",
        "learning_rate",
        "weight_decay",
        "alpha",
        "num_workers",
    )
    with wandb.init() as run:
        config = run.config
        for key in swept_keys:
            setattr(args, key, getattr(config, key))

        train_relic(args)

def main():
    """Hand ``wrapper_train`` to a wandb agent for the sweep ``sweep_id``."""
    wandb.agent(
        sweep_id,
        function=wrapper_train,
    )


if __name__ == "__main__":
    main()


[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mnaveen-indluru[0m ([33mnaveen-indluru-university-of-central-missouri[0m). Use [1m`wandb login --relogin`[0m to force relogin


Create sweep with ID: kbzwfo0p
Sweep URL: https://wandb.ai/naveen-indluru-university-of-central-missouri/NNDL_Mini_Project/sweeps/kbzwfo0p


[34m[1mwandb[0m: Agent Starting Run: 8b8a3s8p with config:
[34m[1mwandb[0m: 	alpha: 1
[34m[1mwandb[0m: 	batch_size: 128
[34m[1mwandb[0m: 	dataset_name: cifar10
[34m[1mwandb[0m: 	encoder_model_name: resnet18
[34m[1mwandb[0m: 	learning_rate: 0.0001
[34m[1mwandb[0m: 	num_workers: 2
[34m[1mwandb[0m: 	weight_decay: 0.1


Final nesting_dims: [8, 16, 32, 64]
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
