In [None]:
!git clone https://github.com/huynguyentran/SupContrast.git

In [None]:
!pip install tensorboard-logger

In [None]:
import sys
# sys.path.append("/kaggle/working")
sys.path.append("SupContrast")

In [None]:
from __future__ import print_function
import os
import argparse
import time
import math
import importlib

import tensorboard_logger as tb_logger
import torch
import torch.backends.cudnn as cudnn
from torchvision import transforms, datasets
# import SupContrast.networks.resnet_big
# importlib.reload(SupContrast.networks.resnet_big)

import networks.resnet_big
importlib.reload(networks.resnet_big)

from SupContrast.util import TwoCropTransform, AverageMeter
from SupContrast.util import adjust_learning_rate, warmup_learning_rate
from SupContrast.util import set_optimizer, save_model
from SupContrast.networks.resnet_big import SupConViT, SupConResNet, CrossAttViT
from SupContrast.losses import SupConLoss



try:
    import apex
    from apex import amp, optimizers
except ImportError:
    pass


def parse_option():
    """Parse command-line options for SupCon/SimCLR pre-training.

    Besides the raw flags, this derives:

    * ``lr_decay_epochs`` as a list of ints (from a comma-separated string),
    * ``model_name`` encoding the main hyper-parameters,
    * warm-up settings (``warmup_from``, ``warm_epochs``, ``warmup_to``) when
      warm-up is enabled (forced on for ``batch_size > 256``),
    * ``tb_folder`` and ``save_folder``, which are created if missing.

    Returns:
        argparse.Namespace with the parsed and derived options.

    Exits (via ``parser.error``) when ``--dataset path`` is given without
    ``--data_folder``, ``--mean`` and ``--std``.
    """
    parser = argparse.ArgumentParser('argument for training')

    parser.add_argument('--print_freq', type=int, default=10,
                        help='print frequency')
    parser.add_argument('--save_freq', type=int, default=50,
                        help='save frequency')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='batch_size')
    parser.add_argument('--num_workers', type=int, default=16,
                        help='num of workers to use')
    parser.add_argument('--epochs', type=int, default=1000,
                        help='number of training epochs')

    # optimization
    parser.add_argument('--learning_rate', type=float, default=0.05,
                        help='learning rate')
    parser.add_argument('--lr_decay_epochs', type=str, default='700,800,900',
                        help='where to decay lr, can be a list')
    parser.add_argument('--lr_decay_rate', type=float, default=0.1,
                        help='decay rate for learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-4,
                        help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum')

    # model dataset
    parser.add_argument('--model', type=str, default='resnet50')
    parser.add_argument('--dataset', type=str, default='cifar10',
                        choices=['cifar10', 'cifar100', 'path'], help='dataset')
    parser.add_argument('--mean', type=str, help='mean of dataset in path in form of str tuple')
    parser.add_argument('--std', type=str, help='std of dataset in path in form of str tuple')
    parser.add_argument('--data_folder', type=str, default=None, help='path to custom dataset')
    parser.add_argument('--size', type=int, default=32, help='parameter for RandomResizedCrop')

    # method
    parser.add_argument('--method', type=str, default='SupCon',
                        choices=['SupCon', 'SimCLR'], help='choose method')

    # temperature
    parser.add_argument('--temp', type=float, default=0.07,
                        help='temperature for loss function')

    parser.add_argument('--optimizer_name', type=str, default="adamw",
                        help='optimizer of the funciton')

    # other setting
    parser.add_argument('--cosine', action='store_true',
                        help='using cosine annealing')
    parser.add_argument('--syncBN', action='store_true',
                        help='using synchronized batch normalization')
    parser.add_argument('--warm', action='store_true',
                        help='warm-up for large batch training')
    parser.add_argument('--trial', type=str, default='0',
                        help='id for recording multiple runs')

    opt = parser.parse_args()

    # A custom folder dataset has no built-in normalization stats, so they must
    # be given explicitly. Use parser.error rather than assert: the check must
    # survive `python -O`, and the user gets a proper usage message.
    if opt.dataset == 'path' and (opt.data_folder is None
                                  or opt.mean is None
                                  or opt.std is None):
        parser.error('--dataset path requires --data_folder, --mean and --std')

    # set the path according to the environment
    if opt.data_folder is None:
        opt.data_folder = './datasets/'
    opt.model_path = './save/SupCon/{}_models'.format(opt.dataset)
    opt.tb_path = './save/SupCon/{}_tensorboard'.format(opt.dataset)

    # "700,800,900" -> [700, 800, 900]
    opt.lr_decay_epochs = [int(it) for it in opt.lr_decay_epochs.split(',')]

    opt.model_name = '{}_{}_{}_lr_{}_decay_{}_bsz_{}_temp_{}_trial_{}'.\
        format(opt.method, opt.dataset, opt.model, opt.learning_rate,
               opt.weight_decay, opt.batch_size, opt.temp, opt.trial)

    if opt.cosine:
        opt.model_name = '{}_cosine'.format(opt.model_name)

    # warm-up for large-batch training,
    if opt.batch_size > 256:
        opt.warm = True
    if opt.warm:
        opt.model_name = '{}_warm'.format(opt.model_name)
        opt.warmup_from = 0.01
        opt.warm_epochs = 10
        if opt.cosine:
            # Warm up towards the value the cosine schedule would have at the
            # end of the warm-up period (eta_min = lr * decay_rate**3).
            eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
            opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (
                    1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
        else:
            opt.warmup_to = opt.learning_rate

    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    opt.tb_folder = os.path.join(opt.tb_path, opt.model_name)
    os.makedirs(opt.tb_folder, exist_ok=True)

    opt.save_folder = os.path.join(opt.model_path, opt.model_name)
    os.makedirs(opt.save_folder, exist_ok=True)

    return opt


def set_loader(opt):
    """Build the DataLoader used for contrastive pre-training.

    Every sample is passed through ``TwoCropTransform`` wrapping a SimCLR-style
    augmentation pipeline, and the dataset is restricted to a set of class
    names before batching.

    Args:
        opt: options object. Reads ``dataset``, ``data_folder``, ``mean`` and
            ``std`` (strings such as ``"(0.5, 0.5, 0.5)"``, only for
            ``dataset == 'path'``), ``size``, ``batch_size``, ``num_workers``,
            and optionally ``selected_classes``: an iterable of class names to
            keep (default ``{"Pseudopapilledema", "Papilledema"}``).

    Returns:
        torch.utils.data.DataLoader over the filtered dataset (shuffled).

    Raises:
        ValueError: for an unsupported ``opt.dataset``, or when no sample
            belongs to the selected classes.
    """
    import ast  # literal_eval: parse the mean/std strings without eval()

    # construct data loader
    if opt.dataset == 'cifar10':
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
    elif opt.dataset == 'cifar100':
        mean = (0.5071, 0.4867, 0.4408)
        std = (0.2675, 0.2565, 0.2761)
    elif opt.dataset == 'path':
        # literal_eval accepts only Python literals (e.g. a tuple of floats),
        # so a config string cannot execute arbitrary code as eval() could.
        mean = ast.literal_eval(opt.mean)
        std = ast.literal_eval(opt.std)
    else:
        raise ValueError('dataset not supported: {}'.format(opt.dataset))
    normalize = transforms.Normalize(mean=mean, std=std)

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size=opt.size, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        normalize,
    ])
    two_crop = TwoCropTransform(train_transform)

    if opt.dataset == 'cifar10':
        train_dataset = datasets.CIFAR10(root=opt.data_folder,
                                         transform=two_crop,
                                         download=True)
    elif opt.dataset == 'cifar100':
        train_dataset = datasets.CIFAR100(root=opt.data_folder,
                                          transform=two_crop,
                                          download=True)
    elif opt.dataset == 'path':
        train_dataset = datasets.ImageFolder(root=opt.data_folder,
                                             transform=two_crop)
    else:
        raise ValueError(opt.dataset)

    selected_classes = getattr(opt, 'selected_classes', None)
    if selected_classes is None:
        selected_classes = {"Pseudopapilledema", "Papilledema"}
    selected_classes = set(selected_classes)
    print(selected_classes)
    print(train_dataset.classes)

    # Filter on the dataset's stored label list rather than iterating the
    # dataset: indexing the dataset loads every image and runs both augmented
    # crops just to read its label.
    targets = train_dataset.targets
    selected_ids = {class_idx for class_idx, name in enumerate(train_dataset.classes)
                    if name in selected_classes}
    filtered_indices = [i for i, target in enumerate(targets) if target in selected_ids]
    if not filtered_indices:
        raise ValueError('no samples of classes {} in dataset classes {}'.format(
            sorted(selected_classes), train_dataset.classes))
    filtered_dataset = torch.utils.data.Subset(train_dataset, filtered_indices)

    train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        filtered_dataset, batch_size=opt.batch_size, shuffle=(train_sampler is None),
        num_workers=opt.num_workers, pin_memory=True, sampler=train_sampler)

    # Without drop_last the loader yields every filtered index once per epoch,
    # so its classes are exactly those of filtered_indices. Computing them from
    # the labels avoids an extra full pass over the images.
    seen_classes = {train_dataset.classes[targets[i]] for i in filtered_indices}
    print("Classes present in this train_loader:", seen_classes)
    return train_loader


def set_model(opt):
    """Build the CrossAttViT model and the SupCon loss.

    Args:
        opt: options object; reads ``model`` (name passed to CrossAttViT),
            ``temp`` (loss temperature) and ``syncBN``.

    Returns:
        ``(model, criterion)``, moved to CUDA when a GPU is available.

    Raises:
        ImportError: if ``opt.syncBN`` is set but NVIDIA apex is not installed.
    """
    # NOTE(review): CrossAttViT is imported from SupContrast.networks.resnet_big,
    # whereas the importlib.reload at import time targets networks.resnet_big,
    # a differently-named module; confirm the reload really refreshes this class.
    model = CrossAttViT(name=opt.model)
    # Earlier model selection, kept for reference:
    # if 'vit' in opt.model:
    #     model = SupConViT(name=opt.model)
    # else:
    #     model = SupConResNet(name=opt.model)
    criterion = SupConLoss(temperature=opt.temp)

    # enable synchronized Batch Normalization
    if opt.syncBN:
        # The module-level apex import swallows ImportError, so `apex` may be
        # undefined here. Import it locally to fail with a clear message
        # instead of a NameError.
        try:
            import apex
        except ImportError as err:
            raise ImportError('--syncBN requires NVIDIA apex to be installed') from err
        model = apex.parallel.convert_syncbn_model(model)

    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            # Only the `encoder` submodule is wrapped in DataParallel; assumes
            # CrossAttViT exposes an `encoder` attribute — TODO confirm.
            model.encoder = torch.nn.DataParallel(model.encoder)
        device = "cuda"
        model = model.to(device)
        criterion = criterion.to(device)
        cudnn.benchmark = True

    return model, criterion

In [None]:

def train(train_loader, model, criterion, optimizer, epoch, opt):
    """Run one training epoch and return the average loss.

    Args:
        train_loader: yields ``(images, labels)`` where ``images`` is a pair of
            augmented batches (``images[0]``, ``images[1]``).
        model: called as ``model(images, bsz)``; returns two feature tensors,
            one per view.
        criterion: contrastive loss; called with ``labels`` for SupCon and
            without them for SimCLR.
        optimizer: optimizer whose learning rate may be warmed up per batch.
        epoch: current epoch number (``main`` counts from 1).
        opt: options; reads ``method``, ``print_freq`` and the warm-up fields
            used by ``warmup_learning_rate``, plus two optional attributes:
            ``max_batches`` (stop after this many batches, for quick debug
            runs; default: whole epoch) and ``grad_clip`` (max gradient norm;
            default 1.0).

    Returns:
        ``losses.avg`` from the AverageMeter, which is updated once per batch
        with that batch's loss and weight ``bsz``.

    Raises:
        ValueError: if ``opt.method`` is neither 'SupCon' nor 'SimCLR'.
    """
    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    max_batches = getattr(opt, 'max_batches', None)
    grad_clip = getattr(opt, 'grad_clip', 1.0)
    num_batches = len(train_loader)

    end = time.time()
    for idx, (images, labels) in enumerate(train_loader):
        # Optional early stop for debugging; by default the full epoch runs.
        if max_batches is not None and idx >= max_batches:
            break
        data_time.update(time.time() - end)

        # Put both views in one batch: the first view's rows come first, then
        # the second's (each view presumably has bsz rows).
        images = torch.cat([images[0], images[1]], dim=0)
        if torch.cuda.is_available():
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
        bsz = labels.shape[0]

        # warm-up learning rate
        warmup_learning_rate(opt, epoch, idx, num_batches, optimizer)

        # compute loss: stack the two per-view features on a new dim 1, so
        # features[:, 0] comes from f1 and features[:, 1] from f2.
        f1, f2 = model(images, bsz)
        features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)

        if opt.method == 'SupCon':
            loss = criterion(features, labels)
        elif opt.method == 'SimCLR':
            loss = criterion(features)
        else:
            raise ValueError('contrastive method not supported: {}'.
                             format(opt.method))

        # update metric
        losses.update(loss.item(), bsz)

        # optimizer step, with gradient-norm clipping
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if (idx + 1) % opt.print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})'.format(
                   epoch, idx + 1, num_batches, batch_time=batch_time,
                   data_time=data_time, loss=losses))
            sys.stdout.flush()

    return losses.avg


def main(opt=None):
    """Run contrastive pre-training end to end.

    The steps are:

    1. Build the data loader, the model and criterion, and the optimizer.
    2. For each epoch, adjust the learning rate, train one epoch and print the
       loss.
    3. Save a checkpoint every ``opt.save_freq`` epochs.
    4. Finally, write ``last.pth`` into ``opt.save_folder``.

    Args:
        opt: options object (an ``Opt`` instance or the namespace returned by
            ``parse_option()``). When omitted, options are parsed from the
            command line, so ``main()`` also works as a script entry point.
    """
    if opt is None:
        opt = parse_option()

    # build data loader
    train_loader = set_loader(opt)

    # build model and criterion
    model, criterion = set_model(opt)

    # build optimizer
    optimizer = set_optimizer(opt, model)

    # tensorboard logging is currently disabled:
    # logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    # training routine
    for epoch in range(1, opt.epochs + 1):
        adjust_learning_rate(opt, optimizer, epoch)

        # train for one epoch
        time1 = time.time()
        loss = train(train_loader, model, criterion, optimizer, epoch, opt)
        time2 = time.time()
        print('epoch {}, loss {:.4f}, total time {:.2f}s'.format(epoch, loss, time2 - time1))

        # logger.log_value('loss', loss, epoch)
        # logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)

        if epoch % opt.save_freq == 0:
            save_file = os.path.join(
                opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            save_model(model, optimizer, opt, epoch, save_file)

    # save the last model
    save_file = os.path.join(
        opt.save_folder, 'last.pth')
    save_model(model, optimizer, opt, opt.epochs, save_file)

In [None]:
import os
import math
import argparse
import time
import torch

# Define an Opt class to hold the configuration
class Opt:
    """Keyword-argument stand-in for ``parse_option()`` inside the notebook.

    Takes the same settings as the command-line flags, with the same names and
    defaults, and derives the same extra fields:

    * ``lr_decay_epochs`` as a list of ints,
    * ``model_name``,
    * warm-up settings,
    * ``tb_folder`` and ``save_folder``, both created if missing.

    Keyword arguments that are not one of the known settings are stored as
    attributes unchanged, rather than silently dropped.

    Raises:
        ValueError: if ``dataset='path'`` is given without ``data_folder``,
            ``mean`` and ``std``.
    """

    def __init__(self, **kwargs):
        # Default values as per the parse_option function
        self.print_freq = kwargs.get('print_freq', 10)
        self.save_freq = kwargs.get('save_freq', 50)
        self.batch_size = kwargs.get('batch_size', 256)
        self.num_workers = kwargs.get('num_workers', 16)
        self.epochs = kwargs.get('epochs', 1000)

        # optimization
        self.learning_rate = kwargs.get('learning_rate', 0.05)
        self.lr_decay_epochs = kwargs.get('lr_decay_epochs', '700,800,900')
        self.lr_decay_rate = kwargs.get('lr_decay_rate', 0.1)
        self.weight_decay = kwargs.get('weight_decay', 1e-4)
        self.momentum = kwargs.get('momentum', 0.9)

        # model dataset
        self.model = kwargs.get('model', 'resnet50')
        self.dataset = kwargs.get('dataset', 'cifar10')
        self.mean = kwargs.get('mean', None)
        self.std = kwargs.get('std', None)
        self.data_folder = kwargs.get('data_folder', None)
        self.size = kwargs.get('size', 32)

        # method
        self.method = kwargs.get('method', 'SupCon')

        # temperature
        self.temp = kwargs.get('temp', 0.07)

        # other settings
        self.cosine = kwargs.get('cosine', False)
        self.syncBN = kwargs.get('syncBN', False)
        self.warm = kwargs.get('warm', False)
        self.trial = kwargs.get('trial', '0')
        self.optimizer_name = kwargs.get('optimizer_name', 'adamw')

        # A custom folder dataset needs explicit normalization stats. Raise
        # instead of assert so the check also runs under `python -O`.
        if self.dataset == 'path' and (self.data_folder is None
                                       or self.mean is None
                                       or self.std is None):
            raise ValueError("dataset='path' requires data_folder, mean and std")

        # Set the path according to the environment
        if self.data_folder is None:
            self.data_folder = './datasets/'

        self.model_path = './save/SupCon/{}_models'.format(self.dataset)
        self.tb_path = './save/SupCon/{}_tensorboard'.format(self.dataset)

        # Accept the CLI-style "700,800,900" string or a sequence of ints.
        decay_epochs = self.lr_decay_epochs
        if isinstance(decay_epochs, str):
            decay_epochs = decay_epochs.split(',')
        self.lr_decay_epochs = [int(it) for it in decay_epochs]

        # Prepare model name
        self.model_name = '{}_{}_{}_lr_{}_decay_{}_bsz_{}_temp_{}_trial_{}'.format(
            self.method, self.dataset, self.model, self.learning_rate,
            self.weight_decay, self.batch_size, self.temp, self.trial)

        if self.cosine:
            self.model_name = '{}_cosine'.format(self.model_name)

        # Warm-up for large-batch training
        if self.batch_size > 256:
            self.warm = True
        if self.warm:
            self.model_name = '{}_warm'.format(self.model_name)
            self.warmup_from = 0.01
            self.warm_epochs = 10
            if self.cosine:
                # Target = cosine-schedule value at the end of warm-up.
                eta_min = self.learning_rate * (self.lr_decay_rate ** 3)
                self.warmup_to = eta_min + (self.learning_rate - eta_min) * (
                        1 + math.cos(math.pi * self.warm_epochs / self.epochs)) / 2
            else:
                self.warmup_to = self.learning_rate

        # exist_ok avoids the check-then-create race of isdir() + makedirs().
        self.tb_folder = os.path.join(self.tb_path, self.model_name)
        os.makedirs(self.tb_folder, exist_ok=True)

        self.save_folder = os.path.join(self.model_path, self.model_name)
        os.makedirs(self.save_folder, exist_ok=True)

        # Keep any extra settings instead of silently ignoring them. Values
        # already set or derived above are never overwritten.
        for key, value in kwargs.items():
            if not hasattr(self, key):
                setattr(self, key, value)

# Optional CUDA memory tweaks, currently disabled:
# torch.cuda.empty_cache()
# os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Notebook run configuration (settings grouped by concern).
opt = Opt(
    # data
    dataset="path",
    data_folder="/kaggle/input/identification-of-pseudopapilledema",
    # NOTE(review): this mean equals the CIFAR-10 mean and this std the
    # CIFAR-100 std used in set_loader; neither appears to be computed on this
    # dataset — confirm.
    mean="(0.4914, 0.4822, 0.4465)",
    std="(0.2675, 0.2565, 0.2761)",
    size=224,
    num_workers=4,
    # model and loss
    model='vit_l_dino',
    temp=0.05,
    # optimization
    optimizer_name='adamw',
    batch_size=8,
    learning_rate=1e-5,
    cosine=True,
    warm=True,
    epochs=100,
    # lr_decay_epochs="5,7,9",
    # lr_decay_rate=0.1,
    # weight_decay=5e-2,
    # logging and checkpoints
    print_freq=100,
    save_freq=100,
)

main(opt)

In [None]:
# opt = Opt(
#     batch_size=32,
#     learning_rate=0.5, 
#     save_freq =10,
#     temp=0.1,
#     cosine=True,
#     num_workers=4,
#     dataset="path",
#     data_folder="/kaggle/input/identification-of-pseudopapilledema",
#     mean="(0.4914, 0.4822, 0.4465)",
#     std="(0.2675, 0.2565, 0.2761)",
#     method="SupCon",
#     size = 224,
#     model='vit_s_dino',
# )



# opt = Opt(
#     batch_size=32,
#     learning_rate=0.5, 
#     save_freq =10,
#     temp=0.1,
#     cosine=True,
#     num_workers=4,
#     dataset="path",
#     data_folder="/kaggle/input/identification-of-pseudopapilledema",
#     mean="(0.4914, 0.4822, 0.4465)",
#     std="(0.2675, 0.2565, 0.2761)",
#     method="SupCon",
#     size = 32,
#     model='resnet50',
# )

# main(opt)