IMPORTS

In [1]:
# !pip install torchmetrics
# !pip install tqdm
# !pip install terminaltables
# !pip install autocast
# # !pip install ch
# !pip install pytorch-lightning
# !pip install hydra-core

In [2]:
import pickle
from torch.utils.data import DataLoader, Dataset
import os
import torch
from torchvision import transforms
from PIL import Image

import sys
from torch.cuda.amp import GradScaler
from torch.cuda.amp import autocast
import torch.nn.functional as F
import torch.distributed as dist
from torchvision import models
import torchmetrics
import numpy as np
from tqdm import tqdm
import os
import time
import json
from uuid import uuid4
from typing import List
from pathlib import Path
# from fastargs import get_current_config
# from fastargs.decorators import param
# from fastargs import Param, Section
# from fastargs.validation import And, OneOf

# from ffcv.pipeline.operation import Operation
# from ffcv.loader import Loader, OrderOption
# from ffcv.transforms import ToTensor, ToDevice, Squeeze, NormalizeImage, \
#     RandomHorizontalFlip, ToTorchImage
# from ffcv.fields.rgb_image import CenterCropRGBImageDecoder, \
#     RandomResizedCropRGBImageDecoder
# from ffcv.fields.basics import IntDecoder

MATRYOSHKA FUNCTIONS

In [3]:

'''
Loss function for Matryoshka Representation Learning
'''
import torch
import torch.nn as nn

class Matryoshka_CE_Loss(nn.Module):
    """Sum of per-granularity cross-entropy losses for Matryoshka Representation Learning.

    Args:
        relative_importance: optional sequence with one weight per nested
            logit set; ``None`` weights every granularity equally.
        **kwargs: forwarded unchanged to ``nn.CrossEntropyLoss``.
    """

    def __init__(self, relative_importance=None, **kwargs):
        super(Matryoshka_CE_Loss, self).__init__()
        self.criterion = nn.CrossEntropyLoss(**kwargs)
        self.relative_importance = relative_importance

    def forward(self, output, target):
        """Return the (optionally weighted) sum of the losses of every entry in ``output``.

        ``output`` is iterated, so it may be a tuple/list of logit tensors or a
        tensor whose first dimension indexes the nested heads.
        """
        losses = torch.stack([self.criterion(output_i, target) for output_i in output])
        if self.relative_importance is None:
            # Equal weights: multiplying by ones is the same as a plain sum.
            return losses.sum()
        # Build the weights on the same device/dtype as the losses; a bare
        # torch.tensor(...) lives on the CPU and fails against CUDA losses.
        rel_importance = torch.tensor(self.relative_importance, dtype=losses.dtype, device=losses.device)
        weighted_losses = rel_importance * losses
        return weighted_losses.sum()

class MRL_Linear_Layer(nn.Module):
    """Classification head that produces one logit set per nested feature size.

    Args:
        nesting_list: feature sizes; entry ``k`` means "use the first k features".
        num_classes: number of output classes of every classifier.
        efficient: when truthy (MRL-E) a single classifier sized for
            ``nesting_list[-1]`` is shared by all sizes; otherwise one
            independent ``nn.Linear`` is created per size.
        **kwargs: forwarded to ``nn.Linear`` (e.g. ``bias=False``).
    """

    def __init__(self, nesting_list, num_classes=1000, efficient=False, **kwargs):
        super(MRL_Linear_Layer, self).__init__()
        self.nesting_list = nesting_list
        self.num_classes = num_classes
        self.efficient = efficient
        if self.efficient:
            setattr(self, f'nesting_classifier_{0}', nn.Linear(nesting_list[-1], self.num_classes, **kwargs))
        else:
            for i, num_feat in enumerate(self.nesting_list):
                setattr(self, f'nesting_classifier_{i}', nn.Linear(num_feat, self.num_classes, **kwargs))

    def reset_parameters(self):
        """Re-initialise every classifier owned by this layer."""
        if self.efficient:
            self.nesting_classifier_0.reset_parameters()
        else:
            for i in range(len(self.nesting_list)):
                getattr(self, f'nesting_classifier_{i}').reset_parameters()

    def forward(self, x):
        """Return a tuple with one ``(batch, num_classes)`` logit tensor per nesting size."""
        nesting_logits = ()
        for i, num_feat in enumerate(self.nesting_list):
            features = x[:, :num_feat]
            if self.efficient:
                # The shared classifier is sized for the largest nesting, so its
                # weight must be sliced to the same leading columns as the input;
                # calling it directly on the sliced input raises a shape mismatch.
                shared = self.nesting_classifier_0
                logits = nn.functional.linear(features, shared.weight[:, :num_feat], shared.bias)
            else:
                logits = getattr(self, f'nesting_classifier_{i}')(features)
            nesting_logits += (logits,)
        return nesting_logits

class FixedFeatureLayer(nn.Linear):
    """Linear layer that reads only the first ``in_features`` columns of its input."""

    def __init__(self, in_features, out_features, **kwargs):
        super(FixedFeatureLayer, self).__init__(in_features, out_features, **kwargs)

    def forward(self, x):
        """Project ``x[:, :in_features]`` with the layer weight and add the bias when present."""
        projected = torch.matmul(x[:, :self.in_features], self.weight.t())
        if self.bias is None:
            return projected
        return projected + self.bias

# Nested representation sizes (powers of two from 8 to 2048) and an MRL-E head
# built from them (efficient=True creates a single classifier sized for the
# largest entry).
# NOTE(review): CIFARTrainer builds its own nesting list and head; these two
# module-level objects do not appear to be used by the visible cells — confirm.
nesting_list = [8, 16, 32, 64, 128, 256, 512, 1024, 2048]
fc_layer = MRL_Linear_Layer(nesting_list, num_classes=1000, efficient=True)

INPUTS


In [4]:
from torchvision.models import ResNet18_Weights
# Weight preset handed to the torchvision model constructor in
# CIFARTrainer.create_model_and_scaler (as its default `weights` argument).
weights = ResNet18_Weights.DEFAULT  # Define weights based on model

'''
This code is directly taken from FFCV-Imagenet https://github.com/libffcv/ffcv-imagenet
and modified for MRL purpose.
'''
sys.path.append("../") # adding root folder to the path

# Let cuDNN benchmark convolution algorithms.
torch.backends.cudnn.benchmark = True
# NOTE(review): the two calls below only construct profiler objects; they are
# never entered with a `with` statement, so they look like no-ops — confirm
# whether they can be removed.
torch.autograd.profiler.emit_nvtx(False)
torch.autograd.profiler.profile(False)

# Module-level configuration. Most of these names are captured as default
# argument values by the helper functions and CIFARTrainer methods defined in
# later cells, so changing them only takes effect if those cells are re-run.
# NOTE(review): several names here (config_file, model_fixed_feature,
# num_workers, in_memory, learning_rate, pretrained, resolution, batch_size)
# do not appear to be read by the visible cells — confirm before relying on them.
config_file = 'rn50_configs/rn50_40_epochs.yaml'
model_fixed_feature = 2048
# train_dataset = os.environ['WRITE_DIR'] + '/train_500_0.50_90.ffcv'
# val_dataset = os.environ['WRITE_DIR'] + '/val_500_uncompressed.ffcv'
num_workers = 12
in_memory = True
logging_folder = 'trainlogs'
# Training only evaluates/logs per epoch when log_level > 0.
log_level = 0
world_size = 2
# Reassigned to 0 further down in this cell.
distributed = False
learning_rate = 0.425

# Model selection: architecture name looked up on torchvision.models, and the
# MRL switches (efficient / mrl enable the nested head; nesting_start is the
# exponent of the smallest nested size, i.e. 2**nesting_start).
arch='resnet18'
pretrained=0
efficient=0
mrl=0
nesting_start=3
fixed_feature=2048


# Progressive-resolution settings used by CIFARTrainer.get_resolution.
min_res=160
max_res=160
end_ramp=0
start_ramp=0


# Learning-rate schedule settings used by get_step_lr / get_cyclic_lr / get_constant_lr.
step_ratio=0.1
step_length=30
lr_schedule_type='cyclic'
lr=0.5
lr_peak_epoch=2



# Parent directory for per-run log folders.
folder=logging_folder


# The DataLoaders in the DATASET LOADING cell are built with batch_size=128,
# not with this value.
batch_size=512
resolution=224
# Non-zero enables flip test-time augmentation in the validation loops.
lr_tta=1


eval_only=0
path=None
batch_size=512
optimizer='sgd'
momentum=0.9
weight_decay=4e-5
epochs=15
label_smoothing=0.1
distributed=0
use_blurpool=0



# Rendezvous address/port for torch.distributed.
# NOTE(review): port is an int, while os.environ values must be strings, so it
# needs converting before being exported as MASTER_PORT.
address='localhost'
port=12355



DATASET LOADING

In [5]:
class CIFAR10Dataset(Dataset):
    """Map-style dataset over channel-first image arrays and their labels.

    Args:
        data: indexable collection of ``(C, H, W)`` arrays.
        labels: indexable collection of labels aligned with ``data``.
        transform: optional callable applied to the PIL image.
        label_transform: optional callable applied to the label.
    """

    def __init__(self, data, labels, transform=None, label_transform=None):
        self.data, self.labels = data, labels
        self.transform, self.label_transform = transform, label_transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx`` after the optional transforms."""
        # Channel-first -> channel-last, as expected by Image.fromarray.
        pixels = self.data[idx].transpose((1, 2, 0))
        label = self.labels[idx]
        image = Image.fromarray(pixels.astype('uint8'))
        if self.transform:
            image = self.transform(image)
        if self.label_transform:
            label = self.label_transform(label)
        return image, label

def load_cifar10_batch(file):
    """Read one pickled CIFAR-10 batch file.

    Returns:
        ``(data, labels)`` where ``data`` is the stored ``'data'`` array
        reshaped to ``(-1, 3, 32, 32)`` and ``labels`` is the stored
        ``'labels'`` entry, returned as-is.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(file, 'rb') as handle:
        payload = pickle.load(handle, encoding='latin1')
    raw_pixels, labels = payload['data'], payload['labels']
    return raw_pixels.reshape(-1, 3, 32, 32), labels

def load_cifar10_data(data_dir):
    """Load the CIFAR-10 training and test splits from ``data_dir``.

    Reads ``data_batch_1`` … ``data_batch_5`` and ``test_batch`` through
    ``load_cifar10_batch``.

    Returns:
        ``(train_data, train_labels, test_data, test_labels)``; the data arrays
        are stacked along the first axis and the labels are numpy arrays.
    """
    train_data = []
    train_labels = []
    for i in range(1, 6):
        # Use the loop index in the file name: a literal 1 here would load the
        # first batch five times and silently drop the other four.
        batch_data, batch_labels = load_cifar10_batch(os.path.join(data_dir, f'data_batch_{i}'))
        train_data.append(batch_data)
        train_labels.extend(batch_labels)
    train_data = np.vstack(train_data)
    train_labels = np.array(train_labels)

    test_data, test_labels = load_cifar10_batch(os.path.join(data_dir, 'test_batch'))
    # Idempotent when the loader already returns (N, 3, 32, 32) data.
    test_data = test_data.reshape(-1, 3, 32, 32)
    test_labels = np.array(test_labels)
    return train_data, train_labels, test_data, test_labels

# Directory holding the pickled CIFAR-10 batch files (relative to the working directory).
# data_dir = 'cifar/'  # Modify with actual path
data_dir = 'data/cifar-10-batches-py/'
train_data, train_labels, test_data, test_labels = load_cifar10_data(data_dir)

# Device that input batches are moved to inside the trainer loops.
# NOTE(review): hard-coded to CUDA; no CPU fallback is provided here.
this_device = torch.device('cuda')

# Training pipeline: random crop (with 4px padding), horizontal flip and a
# rotation of up to 10 degrees, then tensor conversion and per-channel normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10), #newly added
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# Evaluation pipeline: no augmentation, same normalisation constants as training.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
def label_transform(label):
    """Wrap a class label in an int64 tensor; no device transfer happens here."""
    as_long = torch.tensor(label, dtype=torch.int64)
    return as_long

# Wrap the arrays in Dataset objects; the augmenting transform is used for the
# training split only.
train_dataset = CIFAR10Dataset(train_data, train_labels, transform=transform_train, label_transform=label_transform)
test_dataset = CIFAR10Dataset(test_data, test_labels, transform=transform_test, label_transform=label_transform)

# Both loaders use batch_size=128 and num_workers=0; only the training loader
# shuffles. These objects become the default loaders of CIFARTrainer.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=0)
val_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=0)

HELPER FNS

In [6]:
# Per-channel mean/std scaled to the 0-255 range, plus a centre-crop ratio.
# NOTE(review): these values differ from the constants passed to
# transforms.Normalize in the dataset cell, and none of the three names appears
# to be read by the visible cells — confirm whether they are still needed.
CIFAR_MEAN = np.array([0.485, 0.456, 0.406]) * 255
CIFAR_STD = np.array([0.229, 0.224, 0.225]) * 255
DEFAULT_CROP_RATIO = 224/256

def get_step_lr(epoch, lr=lr, step_ratio=step_ratio, step_length=step_length, epochs=epochs):
    """Step schedule: multiply ``lr`` by ``step_ratio`` once per ``step_length`` epochs.

    Returns 0 once ``epoch`` reaches ``epochs``.
    """
    if epoch < epochs:
        completed_steps = epoch // step_length
        return step_ratio**completed_steps * lr
    return 0

def get_constant_lr(epoch, lr=lr):
    """Constant schedule: return ``lr`` regardless of ``epoch``."""
    return lr

def get_cyclic_lr(epoch, lr=lr, epochs=epochs, lr_peak_epoch=lr_peak_epoch):
    """Triangular schedule: ramp from ``1e-4 * lr`` up to ``lr`` at ``lr_peak_epoch``, then down to 0 at ``epochs``.

    ``epoch`` may be fractional; the value is linearly interpolated.
    """
    knot_epochs = [0, lr_peak_epoch, epochs]
    knot_values = [1e-4 * lr, lr, 0]
    interpolated = np.interp([epoch], knot_epochs, knot_values)
    return interpolated[0]

class BlurPoolConv2d(torch.nn.Module):
    """Wrap a convolution so its input is first smoothed with a fixed 3x3 blur.

    The blur kernel is [[1, 2, 1], [2, 4, 2], [1, 2, 1]] / 16, replicated once
    per input channel and applied depthwise (``groups=conv.in_channels``).
    """

    def __init__(self, conv):
        super().__init__()
        kernel = torch.tensor([[[[1, 2, 1], [2, 4, 2], [1, 2, 1]]]]) / 16.0
        per_channel = kernel.repeat(conv.in_channels, 1, 1, 1)
        self.conv = conv
        # Registered as a buffer so it follows the module across devices
        # without being treated as a trainable parameter.
        self.register_buffer('blur_filter', per_channel)

    def forward(self, x):
        """Blur ``x`` channel-wise, then apply the wrapped convolution."""
        smoothed = F.conv2d(
            x,
            self.blur_filter,
            bias=None,
            stride=1,
            padding=(1, 1),
            groups=self.conv.in_channels,
        )
        return self.conv.forward(smoothed)


CIFARTrainer

In [7]:
# NOTE(review): this import and assignment repeat the ones in the INPUTS cell;
# they are kept so this cell can bind `weights` as a default argument on its own.
from torchvision.models import ResNet18_Weights
weights = ResNet18_Weights.DEFAULT  # Define weights based on model

class CIFARTrainer:
    def __init__(self, gpu, train_loader=train_loader, val_loader=val_loader, distributed=distributed,
                 efficient=efficient, mrl=mrl, nesting_start=nesting_start, fixed_feature=fixed_feature,
                 this_device=this_device):
        """Build the model, optimizer, loss and validation meters for one process.

        Args:
            gpu: device index the model and meters are moved to; index 0 also
                owns the log folder and the final checkpoint.
            train_loader / val_loader: iterables yielding ``(images, target)``.
            distributed: when truthy, initialise the process group and wrap the
                model for distributed training.
            efficient / mrl: either one enables the nested (MRL) head.
            nesting_start: exponent of the smallest nested size (sizes are
                ``2**nesting_start`` … ``2**11``).
            fixed_feature: feature width for the fixed-feature baseline head.
            this_device: device input batches are moved to.
        """
        # self.all_params = get_current_config();
        self.gpu = gpu
        self.efficient = efficient
        self.nesting = (self.efficient or mrl)
        self.nesting_start = nesting_start
        self.nesting_list = [2**i for i in range(self.nesting_start, 12)] if self.nesting else None
        self.fixed_feature = fixed_feature
        self.uid = str(uuid4())
        self.train_loader = train_loader
        self.val_loader = val_loader

        self.this_device = this_device

        if distributed:
            self.setup_distributed()

        # Forward the constructor's flag: without it the model builder falls back
        # to its own module-level default and may skip the distributed wrapper
        # even though the process group was initialised above.
        self.model, self.scaler = self.create_model_and_scaler(distributed=distributed)
        # self.model.cuda().half()
        # Order matters: the optimizer needs the model's parameters, and the
        # logger reads self.nesting_list / self.uid.
        self.create_optimizer()
        self.initialize_logger()


    def setup_distributed(self, address=address, port=port, world_size=world_size):
        """Export the rendezvous address/port and join the NCCL process group as rank ``self.gpu``."""
        # os.environ only accepts strings; the configured port may be an int,
        # which would raise TypeError on assignment.
        os.environ['MASTER_ADDR'] = str(address)
        os.environ['MASTER_PORT'] = str(port)

        dist.init_process_group("nccl", rank=self.gpu, world_size=world_size)
        torch.cuda.set_device(self.gpu)

    def cleanup_distributed(self):
        """Tear down the process group created by ``setup_distributed``."""
        dist.destroy_process_group()

    def get_lr(self, epoch, lr_schedule_type=lr_schedule_type):
        """Return the learning rate for ``epoch`` under the selected schedule.

        Raises ``KeyError`` when ``lr_schedule_type`` is not one of
        ``'cyclic'``, ``'step'`` or ``'constant'``.
        """
        schedule_fns = {
            'cyclic': get_cyclic_lr,
            'step': get_step_lr,
            'constant': get_constant_lr
        }
        chosen_schedule = schedule_fns[lr_schedule_type]
        return chosen_schedule(epoch)

    # resolution tools
    def get_resolution(self, epoch, min_res=min_res, max_res=max_res, end_ramp=end_ramp, start_ramp=start_ramp):
        """Return the training resolution for ``epoch``.

        ``min_res`` up to and including ``start_ramp``, ``max_res`` from
        ``end_ramp`` on, and in between a linear interpolation rounded to the
        nearest multiple of 32.
        """
        assert min_res <= max_res

        if epoch <= start_ramp:
            return min_res
        if epoch >= end_ramp:
            return max_res

        # Inside the ramp: interpolate, then snap to a multiple of 32.
        ramp_value = np.interp([epoch], [start_ramp, end_ramp], [min_res, max_res])[0]
        return int(np.round(ramp_value / 32)) * 32

    def create_optimizer(self, momentum=momentum, optimizer=optimizer, weight_decay=weight_decay,
                         label_smoothing=label_smoothing):
        """Create ``self.optimizer`` (SGD) and ``self.loss``.

        Parameters whose name contains ``'bn'`` get no weight decay; all others
        use ``weight_decay``. The base lr is 1 because the training loop
        overwrites the lr of every group at each iteration.
        """
        assert optimizer == 'sgd'

        # Only do weight decay on non-batchnorm parameters
        bn_params, other_params = [], []
        for name, param in self.model.named_parameters():
            bucket = bn_params if 'bn' in name else other_params
            bucket.append(param)

        param_groups = [
            {'params': bn_params, 'weight_decay': 0.},
            {'params': other_params, 'weight_decay': weight_decay},
        ]
        self.optimizer = torch.optim.SGD(param_groups, lr=1, momentum=momentum)

        # The nested head returns several logit sets, so it needs the summed loss.
        loss_cls = Matryoshka_CE_Loss if self.nesting else torch.nn.CrossEntropyLoss
        self.loss = loss_cls(label_smoothing=label_smoothing)

    def train(self, epochs=epochs, log_level=log_level):
        """Run ``epochs`` training epochs, then save the final weights on gpu 0.

        Validation and logging run after each epoch only when ``log_level > 0``.
        """
        evaluate_each_epoch = log_level > 0
        for epoch in range(epochs):
            print("epoch no. ", epoch)
            # res = self.get_resolution(epoch)
            # self.decoder.output_size = (res, res)
            train_loss = self.train_loop(epoch)

            if evaluate_each_epoch:
                self.eval_and_log({'train_loss': train_loss, 'epoch': epoch})

        # self.eval_and_log({'epoch':epoch})
        if self.gpu == 0:
            torch.save(self.model.state_dict(), self.log_folder / 'final_weights.pt')

    def eval_and_log(self, extra_dict=None):
        """Run the matching validation loop and, on gpu 0, log its statistics.

        Args:
            extra_dict: optional extra key/value pairs merged into the log
                record (they override computed entries with the same key).

        Returns:
            The statistics dict returned by the validation loop (including
            ``'loss'``, which is left out of the log record).
        """
        # A None default avoids sharing one mutable dict between calls.
        extras = {} if extra_dict is None else extra_dict

        start_val = time.time()
        stats = self.val_loop_nesting() if self.nesting else self.val_loop()
        val_time = time.time() - start_val

        if self.gpu == 0:
            record = {
                'current_lr': self.optimizer.param_groups[0]['lr'], 'val_time': val_time
            }
            record.update({k: v for k, v in stats.items() if k != 'loss'})
            self.log(dict(record, **extras))

        return stats

    def create_model_and_scaler(self, arch=arch, weights=weights, distributed=distributed, use_blurpool=use_blurpool):
        '''
        Build the torchvision model named ``arch`` plus a GradScaler.

        Nesting Start is just the log_2 {smallest dim} unit. In our work we used powers of two, however this part can be changed easily.
        If we do not want to use MRL, we just keep both the efficient and mrl flags to 0
        If we want a fixed feature baseline, then we just change fixed_feature={Rep. Size of your choice}

        When nesting is enabled, ``self.nesting_list`` is trimmed to the sizes
        that fit the backbone's feature width before the head is created.

        NOTE: FFCV Uses Blurpool.
        NOTE(review): the replacement heads are built with 1000 classes while the
        non-nesting validation meters are configured for 10 classes - confirm
        which one is intended.

        Returns ``(model, scaler)``.
        '''

        scaler = GradScaler()
        model = getattr(models, arch)(weights=weights)

        if self.nesting:
            # A nested size wider than the backbone's feature vector cannot be
            # sliced out of it (x[:, :k] just returns every column), so the
            # matching classifier would fail with a shape mismatch on the first
            # forward pass. Keep only the sizes the backbone can provide.
            feature_dim = getattr(getattr(model, 'fc', None), 'in_features', None)
            if feature_dim is not None:
                usable_sizes = [size for size in self.nesting_list if size <= feature_dim]
                if not usable_sizes:
                    raise ValueError(
                        f"No nesting size in {self.nesting_list} fits the {feature_dim}-dim features of {arch}")
                if len(usable_sizes) != len(self.nesting_list):
                    print(f"Restricting nesting sizes to {usable_sizes} ({arch} provides {feature_dim} features)")
                self.nesting_list = usable_sizes

            ff= "MRL-E" if self.efficient else "MRL"
            print(f"Creating classification layer of type :\t {ff}")
            model.fc = MRL_Linear_Layer(self.nesting_list, num_classes=1000, efficient=self.efficient)
        elif self.fixed_feature != 2048:
            print("Using Fixed Features.... ")
            model.fc =  FixedFeatureLayer(self.fixed_feature, 1000)

        def apply_blurpool(mod: torch.nn.Module):
            # Recursively wrap strided convolutions with at least 16 input channels.
            for (name, child) in mod.named_children():
                if isinstance(child, torch.nn.Conv2d) and (np.max(child.stride) > 1 and child.in_channels >= 16):
                    setattr(mod, name, BlurPoolConv2d(child))
                else: apply_blurpool(child)
        if use_blurpool: apply_blurpool(model)

        model = model.to(memory_format=torch.channels_last)
        model = model.to(self.gpu)

        if distributed:
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[self.gpu])

        return model, scaler

    def train_loop(self, epoch, log_level=log_level):
        """Train for one epoch with mixed precision.

        The learning rate is interpolated per iteration between the schedule
        value of ``epoch`` and that of ``epoch + 1``.

        Returns:
            The mean training loss when ``log_level > 0``; otherwise ``None``.
        """
        self.model.train()
        step_losses = []

        num_iters = len(self.train_loader)
        lr_start, lr_end = self.get_lr(epoch), self.get_lr(epoch + 1)
        lrs = np.interp(np.arange(num_iters), [0, num_iters], [lr_start, lr_end])

        progress = tqdm(self.train_loader)
        for ix, (images, target) in enumerate(progress):
            images = images.to(self.this_device, non_blocking=True)
            target = target.to(self.this_device, non_blocking=True)

            ### Training start
            step_lr = lrs[ix]
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = step_lr

            self.optimizer.zero_grad(set_to_none=True)
            with autocast():
                # images = images.cuda().half()
                loss_train = self.loss(self.model(images), target)

            # Scale the loss for the backward pass, then step through the scaler.
            self.scaler.scale(loss_train).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
            ### Training end

            ### Logging start
            if log_level <= 0:
                continue

            step_losses.append(loss_train.detach())
            group_lrs = [f'{group["lr"]:.3f}' for group in self.optimizer.param_groups]

            names = ['ep', 'iter', 'shape', 'lrs']
            values = [epoch, ix, tuple(images.shape), group_lrs]
            if log_level > 1:
                names.append('loss')
                values.append(f'{loss_train.item():.3f}')

            progress.set_description(', '.join(f'{n}={v}' for n, v in zip(names, values)))
            ### Logging end

        if log_level > 0:
            loss = torch.stack(step_losses).mean().cpu()
            assert not torch.isnan(loss), 'Loss is NaN!'
            return loss.item()

    def val_loop(self, lr_tta=lr_tta):
        """Evaluate the single-head model on ``self.val_loader``.

        Args:
            lr_tta: when truthy, the logits of the horizontally flipped batch
                (flip along dim 3) are added to the logits of the original batch.

        Returns:
            Dict with one computed value per meter in ``self.val_meters``; the
            meters are reset before returning.
        """
        model = self.model
        model.eval()

        with torch.no_grad():
            with autocast():
                for images, target in tqdm(self.val_loader):
                    # Inputs stay on the configured device in their original
                    # dtype, as in the training and nested validation loops;
                    # autocast handles the reduced-precision casts.
                    images = images.to(self.this_device, non_blocking=True)
                    target = target.to(self.this_device, non_blocking=True)
                    output = self.model(images)
                    if lr_tta:
                        output += self.model(torch.flip(images, dims=[3]))

                    for k in ['top_1', 'top_5']:
                        self.val_meters[k](output, target)

                    loss_val = self.loss(output, target)
                    self.val_meters['loss'](loss_val)

        stats = {k: m.compute().item() for k, m in self.val_meters.items()}
        for meter in self.val_meters.values():
            meter.reset()
        return stats


    def val_loop_nesting(self, lr_tta=lr_tta):
        '''
        Validation for the nested head, which returns a tuple of logits (one per
        nesting size). The tuple is stacked along a new first dimension so that
        entry ``i`` belongs to ``self.nesting_list[i]``.

        Returns a dict with one computed value per meter; meters are reset afterwards.
        '''
        self.model.eval()

        with torch.no_grad(), autocast():
            for images, target in tqdm(self.val_loader):
                images = images.to(self.this_device, non_blocking=True)
                target = target.to(self.this_device, non_blocking=True)
                output = torch.stack(self.model(images), dim=0)

                if lr_tta:
                    # Just one augmentation: add the logits of the flipped batch.
                    flipped = torch.flip(images, dims=[3])
                    output += torch.stack(self.model(flipped), dim=0)

                # Logging the accuracies top1/5 for each of nesting...
                for i, nesting_dim in enumerate(self.nesting_list):
                    self.val_meters["top_1_{}".format(nesting_dim)](output[i], target)
                    self.val_meters["top_5_{}".format(nesting_dim)](output[i], target)

                self.val_meters['loss'](self.loss(output, target))

        stats = {k: m.compute().item() for k, m in self.val_meters.items()}
        for meter in self.val_meters.values():
            meter.reset()
        return stats


    def initialize_logger(self, folder=folder):
        """Create the validation meters and, on gpu 0, the run's log folder.

        With nesting enabled there is one top-1 and one top-5 meter per entry of
        ``self.nesting_list`` (keys ``top_1_<size>`` / ``top_5_<size>``) plus
        ``loss``; otherwise the keys are ``top_1``, ``top_5`` and ``loss``.
        """
        # Same class count for both branches.
        num_classes = 10

        if self.nesting:
            # Use the same task-based Accuracy API as the non-nesting branch;
            # the older `compute_on_step` keyword is not accepted by the
            # torchmetrics versions that provide `task=`.
            self.val_meters={}
            for i in self.nesting_list:
                self.val_meters['top_1_{}'.format(i)] = torchmetrics.Accuracy(
                    task='multiclass', num_classes=num_classes).to(self.gpu)

            for i in self.nesting_list:
                self.val_meters['top_5_{}'.format(i)] = torchmetrics.Accuracy(
                    task='multiclass', top_k=5, num_classes=num_classes).to(self.gpu)

            self.val_meters['loss'] = MeanScalarMetric().to(self.gpu)

        else:
            self.val_meters = {
                'top_1': torchmetrics.Accuracy(task='multiclass', num_classes=num_classes).to(self.gpu),
                'top_5': torchmetrics.Accuracy(task='multiclass', top_k=5, num_classes=num_classes).to(self.gpu),
                'loss': MeanScalarMetric().to(self.gpu)
            }

        if self.gpu == 0:
            # One folder per run, named after the trainer's uuid.
            folder = (Path(folder) / str(self.uid)).absolute()
            folder.mkdir(parents=True)

            self.log_folder = folder
            self.start_time = time.time()

            print(f'=> Logging in {self.log_folder}')
            # params = {
            #     '.'.join(k): self.all_params[k] for k in self.all_params.entries.keys()
            # }

            # with open(folder / 'params.json', 'w+') as handle:
            #     json.dump(params, handle)

    def log(self, content):
        """Print ``content`` and, on gpu 0, append it as one JSON line to ``<log_folder>/log``.

        Each record also carries the wall-clock ``timestamp`` and the
        ``relative_time`` since the logger was initialised.
        """
        print(f'=> Log: {content}')
        if self.gpu == 0:
            cur_time = time.time()
            with open(self.log_folder / 'log', 'a+') as fd:
                record = {
                    'timestamp': cur_time,
                    'relative_time': cur_time - self.start_time,
                    **content
                }
                fd.write(json.dumps(record) + '\n')
                fd.flush()

    @classmethod
    def launch_from_args(cls, distributed=False, world_size=2, eval_only=0):
        """Entry point: run in this process, or spawn ``world_size`` workers when ``distributed``.

        In the distributed case each worker is called with its process index
        followed by ``distributed`` and ``eval_only``, so the workers run with
        the same settings as the caller requested.
        """
        if distributed:
            # spawn() calls fn(i, *args); without `args` the workers would fall
            # back to the defaults of `exec` and ignore the caller's flags.
            torch.multiprocessing.spawn(
                cls._exec_wrapper, args=(distributed, eval_only), nprocs=world_size, join=True)
        else:
            cls.exec(0, distributed, eval_only)

    @classmethod
    def _exec_wrapper(cls, *args, **kwargs):
        """Worker entry point used by ``launch_from_args``; forwards everything to ``exec``."""
        # `make_config` is optional: run it only when a callable of that name is
        # defined at module level, instead of failing with NameError when it is not.
        config_hook = globals().get('make_config')
        if callable(config_hook):
            config_hook(quiet=True)
        cls.exec(*args, **kwargs)

    @classmethod
    def exec(cls, gpu, distributed=distributed, eval_only=eval_only, path=None):
        """Create a trainer on ``gpu`` and either train or evaluate a checkpoint.

        Args:
            gpu: device index for this process.
            distributed: forwarded to the trainer; also controls process-group cleanup.
            eval_only: when truthy, load the state dict at ``path`` and evaluate.
            path: checkpoint location, required when ``eval_only`` is truthy.

        Raises:
            ValueError: if ``eval_only`` is requested without a ``path``.
        """
        # Pass the flag on so the trainer sets up the process group that is
        # torn down below.
        trainer = cls(gpu=gpu, distributed=distributed)
        try:
            if eval_only:
                if path is None:
                    raise ValueError("eval_only requires `path` to point to a saved state dict")
                print("Loading Model.....")
                ckpt = torch.load(path, map_location="cuda:{}".format(gpu))
                trainer.model.load_state_dict(ckpt)
                print("Loading Complete!")
                trainer.eval_and_log()
            else:
                trainer.train()
        finally:
            # Release the process group even when training/evaluation raises.
            if distributed:
                trainer.cleanup_distributed()

# Utils
class MeanScalarMetric(torchmetrics.Metric):
    """Running mean of all elements passed to ``update``.

    Keeps two states, ``sum`` and ``count``, both registered with a ``'sum'``
    reduction for distributed synchronisation.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.add_state('sum', default=torch.tensor(0.), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')

    def update(self, sample: torch.Tensor):
        """Accumulate the element sum and the element count of ``sample``."""
        self.sum += sample.sum()
        self.count += sample.numel()

    def compute(self):
        """Return the accumulated sum divided by the accumulated element count."""
        return self.sum.float() / self.count




In [8]:
# Running
# def make_config(quiet=False):
#     config = get_current_config()
#     parser = ArgumentParser(description='Fast CIFAR training')
#     config.validate(mode='stderr')
#     if not quiet:
#         config.summary()

# Start training (or evaluation) with the module-level settings; in a notebook
# __name__ is "__main__", so this runs when the cell is executed.
if __name__ == "__main__":
    # make_config()
    CIFARTrainer.launch_from_args(distributed, world_size, eval_only)

=> Logging in D:\Fall2024\NeuralNetworkDeepLearning\Project\MiniProject\NNDL_MRL_MiniProject\trainlogs\42edc371-7f56-4431-bece-7c7b4df3d60f
epoch no.  0


100%|██████████| 391/391 [00:17<00:00, 22.17it/s]


epoch no.  1


100%|██████████| 391/391 [00:16<00:00, 24.12it/s]


epoch no.  2


100%|██████████| 391/391 [00:15<00:00, 24.59it/s]


epoch no.  3


100%|██████████| 391/391 [00:16<00:00, 24.22it/s]


epoch no.  4


100%|██████████| 391/391 [00:16<00:00, 23.34it/s]


epoch no.  5


100%|██████████| 391/391 [00:16<00:00, 23.96it/s]


epoch no.  6


100%|██████████| 391/391 [00:15<00:00, 24.57it/s]


epoch no.  7


100%|██████████| 391/391 [00:15<00:00, 24.57it/s]


epoch no.  8


100%|██████████| 391/391 [00:16<00:00, 24.43it/s]


epoch no.  9


100%|██████████| 391/391 [00:16<00:00, 23.54it/s]


epoch no.  10


100%|██████████| 391/391 [00:15<00:00, 24.64it/s]


epoch no.  11


100%|██████████| 391/391 [00:15<00:00, 24.75it/s]


epoch no.  12


100%|██████████| 391/391 [00:15<00:00, 24.67it/s]


epoch no.  13


100%|██████████| 391/391 [00:15<00:00, 24.71it/s]


epoch no.  14


100%|██████████| 391/391 [00:15<00:00, 24.67it/s]
