In [9]:
import torch
from typing import Iterable


class MergeLoader(Iterable):
    '''Re-batch and merge several (data, label) batch streams into one.

    Each source in ``datasets`` must be re-iterable (``iter()`` restarts it)
    and yield ``(data, label)`` pairs of tensors that can be concatenated
    along dim 0 -- e.g. a ``torch.utils.data.DataLoader``. Emitted batches
    hold ``batch_size`` samples; the last batch of a pass may be smaller.

    * ``shuffle=True``: every output batch takes a random number of samples
      from each source and is then permuted.
    * ``shuffle=False``: sources are drained one after another, in order.

    Calling ``iter()`` on the loader (e.g. starting a ``for`` loop) begins a
    new pass over all sources.
    '''

    def __init__(self, datasets, batch_size, shuffle):
        '''
        datasets: list of batch iterables -- each re-iterable and yielding
            (data, label) tensor pairs; a source may be empty
        batch_size: int -- size of each emitted batch
        shuffle: bool -- whether to interleave and shuffle the merged stream
        '''
        self.datasets = datasets
        self.num_datasets = len(datasets)
        self.bsz = batch_size
        self.shuffle = shuffle
        # None means "no pass in progress"; __next__ then calls init_buffers()
        self.buffers = None

    def init_buffers(self):
        '''Start a new pass: create fresh iterators and prefetch a first batch.

        Shuffled mode keeps one ``[data, label]`` buffer per source (``None``
        for a source that yielded nothing at all). Unshuffled mode keeps a
        single ``[data, label]`` buffer fed from
        ``self.datasets[self.current_dataset]``.

        Raises StopIteration (unshuffled mode) when every source is empty.
        '''
        if self.shuffle:
            self.iterators = [iter(d) for d in self.datasets]
            self.buffers = []
            self.exhausted = []
            for it in self.iterators:
                first = next(it, None)
                # list() so the buffer can be updated in place even when the
                # source yields tuples
                self.buffers.append(None if first is None else list(first))
                self.exhausted.append(first is None)
        else:
            self.current_dataset = 0
            while self.current_dataset < self.num_datasets:
                self.iterator = iter(self.datasets[self.current_dataset])
                first = next(self.iterator, None)
                if first is not None:
                    self.buffers = list(first)
                    return
                # empty source: move on instead of ending the whole pass
                self.current_dataset += 1
            self.buffers = None
            raise StopIteration

    def __iter__(self):
        # restart from scratch so every `for` loop is a full pass, even if a
        # previous loop stopped early
        self.buffers = None
        return self

    def unshuffle_pump(self, is_returning=False):
        '''Top up the single buffer to at least ``bsz`` samples, moving on to
        the next source whenever the current one runs out.

        is_returning: True when called right before emitting a batch; if the
            buffer is still empty at that point every source is drained, so
            the pass ends (buffers reset to None, StopIteration raised).
        '''
        while self.buffers[0].size(0) < self.bsz and self.current_dataset < self.num_datasets:
            try:
                _data, _label = next(self.iterator)
                self.buffers[0] = torch.cat([self.buffers[0], _data])
                self.buffers[1] = torch.cat([self.buffers[1], _label])
            except StopIteration:
                self.current_dataset += 1
                if self.current_dataset < self.num_datasets:
                    self.iterator = iter(self.datasets[self.current_dataset])
                else:
                    break

        # if the buffer is still empty, every source is exhausted
        if self.buffers[0].size(0) == 0 and is_returning:
            self.buffers = None
            raise StopIteration

    def shuffle_pump(self, idx, req_size):
        '''Grow source ``idx``'s buffer to at least ``req_size`` samples, or
        until that source runs out (then ``self.exhausted[idx]`` is set).'''
        while self.buffers[idx][0].size(0) < req_size and not self.exhausted[idx]:
            try:
                _data, _label = next(self.iterators[idx])
                self.buffers[idx][0] = torch.cat([self.buffers[idx][0], _data])
                self.buffers[idx][1] = torch.cat([self.buffers[idx][1], _label])
            except StopIteration:
                self.exhausted[idx] = True

    def __next__(self):
        # lazily start a pass (first call, or after the previous pass ended)
        if self.buffers is None:
            self.init_buffers()

        if self.shuffle:
            batch_placeholder = None
            count_left = self.bsz

            while count_left > 0 and not all(self.exhausted):
                for i in range(self.num_datasets):
                    # split the remaining slots over sources i..n-1: each slot
                    # picks one of them uniformly; the last source takes the rest
                    if i < self.num_datasets - 1:
                        idx_count = (torch.randint(i, self.num_datasets, (count_left,)) == i).sum().item()
                    else:
                        idx_count = count_left

                    # a None buffer marks a source that was empty from the start
                    if idx_count == 0 or self.buffers[i] is None:
                        continue

                    self.shuffle_pump(i, idx_count)
                    data_to_add = self.buffers[i][0][:idx_count]
                    label_to_add = self.buffers[i][1][:idx_count]
                    self.buffers[i] = [self.buffers[i][0][idx_count:], self.buffers[i][1][idx_count:]]

                    # an exhausted source may supply fewer than idx_count samples
                    count_left -= data_to_add.size(0)

                    if batch_placeholder is None:
                        batch_placeholder = data_to_add, label_to_add
                    else:
                        batch_placeholder = (
                            torch.cat([batch_placeholder[0], data_to_add]),
                            torch.cat([batch_placeholder[1], label_to_add])
                        )

            if count_left == self.bsz:
                # nothing could be collected: the pass is over
                self.buffers = None
                raise StopIteration

            # permute so samples from different sources are interleaved
            idxs = torch.randperm(batch_placeholder[0].size(0))
            return batch_placeholder[0][idxs], batch_placeholder[1][idxs]

        # unshuffled: drain the sources sequentially. After the pump the
        # buffer holds at least bsz samples, or every source is exhausted.
        self.unshuffle_pump(is_returning=True)
        ret = (self.buffers[0][:self.bsz], self.buffers[1][:self.bsz])
        self.buffers = [self.buffers[0][self.bsz:], self.buffers[1][self.bsz:]]
        self.unshuffle_pump()
        return ret

In [10]:
from pathlib import Path
from src.utils import set_seed
from src.train_utils import AverageMeter


def _train_one_epoch(model, loader, optimizer, scheduler, criterion, device):
    '''Run one optimisation pass over `loader` and return the epoch's mean loss
    (AverageMeter of per-batch losses, weighted by batch size).'''
    avg_loss = AverageMeter()
    for data, label in loader:
        data, label = data.to(device), label.to(device)
        optimizer.zero_grad()
        loss = criterion(model(data), label)
        loss.backward()
        optimizer.step()
        # NOTE(review): the scheduler is stepped once per batch; epoch-based
        # schedulers (e.g. StepLR) expect one step per epoch -- confirm intent.
        if scheduler:
            scheduler.step()
        avg_loss.update(loss.item(), data.size(0))
    return avg_loss.avg


@torch.no_grad()
def _test_accuracy(model, loader, device):
    '''Return top-1 accuracy of `model` over `loader` (AverageMeter of
    per-batch accuracies, weighted by batch size). Runs without autograd.'''
    acc_meter = AverageMeter()
    for data, label in loader:
        data, label = data.to(device), label.to(device)
        acc = (model(data).argmax(1) == label).float().mean().item()
        acc_meter.update(acc, data.size(0))
    return acc_meter.avg


def evaluate(trainloader, testloader, configs):
    '''Train ``configs.model``, measure test accuracy, and save results.

    trainloader / testloader: an iterable of (data, label) batches, or a
        list/tuple of such iterables, merged with MergeLoader (shuffled for
        training, sequential for testing) into batches of ``configs.bsz``.
    configs: attribute namespace. Reads seed, model, optimizer, lr,
        optimizer_kwargs, scheduler, scheduler_kwargs (only if scheduler is
        set), criterion, device, epochs, bsz, save_dir, name, timestamp.

    Writes ``<name>_results.txt`` and ``<name>_model.pth`` to
    ``save_dir/name/timestamp``. ``configs`` itself is not modified.

    Returns (train_losses, test_acc): the list of per-epoch mean training
    losses and the mean test accuracy.
    '''
    seed = configs.seed
    # 0 is a valid seed; None / False / the `{}` placeholder mean "not set"
    if seed or (isinstance(seed, int) and not isinstance(seed, bool)):
        set_seed(seed)

    if isinstance(trainloader, (list, tuple)):
        trainloader = MergeLoader(trainloader, configs.bsz, True)
    if isinstance(testloader, (list, tuple)):
        testloader = MergeLoader(testloader, configs.bsz, False)

    device = configs.device if configs.device else 'cuda' if torch.cuda.is_available() else 'cpu'
    # put the model on `device` before building the optimizer, so batches and
    # parameters always live on the same device
    model = configs.model.to(device)
    optimizer = configs.optimizer(model.parameters(), lr=configs.lr, **(configs.optimizer_kwargs if configs.optimizer_kwargs else {}))
    scheduler = configs.scheduler(optimizer, **configs.scheduler_kwargs) if configs.scheduler else None
    criterion = configs.criterion
    if isinstance(criterion, torch.nn.Module):
        # loss modules may carry buffers (e.g. class weights) that must match
        criterion = criterion.to(device)

    # train the model with trainloader
    train_losses = []
    model.train()
    for _ in range(configs.epochs):
        train_losses.append(_train_one_epoch(model, trainloader, optimizer, scheduler, criterion, device))

    # evaluate the model with testloader
    model.eval()
    test_acc = _test_accuracy(model, testloader, device)

    # write everything to a file
    savedir = Path(configs.save_dir) / configs.name / configs.timestamp
    savedir.mkdir(exist_ok=True, parents=True)

    with open(savedir / f'{configs.name}_results.txt', 'w') as f:
        f.write(f'{train_losses=}\n')
        f.write(f'acc={test_acc}\n')

    # store every setting except the live model object (its weights go in
    # 'ckpt'); filtering a copy leaves the caller's `configs` usable for reruns
    saved_configs = {k: v for k, v in vars(configs).items() if k != 'model'}

    torch.save({
            'ckpt': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict() if scheduler else None,
            'configs': saved_configs,
        },
        savedir / f'{configs.name}_model.pth'
    )

    avg_train_loss = sum(train_losses) / len(train_losses) if train_losses else float('nan')
    print(f'[+] {configs.name}')
    print(f'    - avg_train_loss={avg_train_loss}')
    print(f'    - avg_test_acc={test_acc}')

    return train_losses, test_acc

In [11]:
# define dataset, copied straight from the acc90 notebook

import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Define a PyTorch Dataset for HAM10000 (the DataLoaders below batch it)
class HAM10000(Dataset):
    '''HAM10000 skin-lesion images, looked up by position in a DataFrame.

    df: DataFrame with a ``path`` column (image file path) and a
        ``cell_type_idx`` column (integer class label).
    transform: optional callable applied to the loaded PIL image.

    Items are ``(image, label)`` with ``label`` a 0-dim integer tensor.
    '''

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        # positional lookup: samplers pass indices in [0, len(df)), which need
        # not match the frame's index labels (e.g. after a train/val split)
        path = self.df['path'].iloc[index]
        label = self.df['cell_type_idx'].iloc[index]

        # open inside `with` so the file handle is released immediately;
        # convert('RGB') loads the pixels and guarantees the 3 channels the
        # normalisation statistics expect
        with Image.open(path) as img:
            X = img.convert('RGB')
        y = torch.tensor(int(label))

        if self.transform:
            X = self.transform(X)

        return X, y
    

import os

# Dataset location; set HAM10000_DIR to use another copy instead of editing
# this machine-specific default.
data_dir = os.environ.get(
    'HAM10000_DIR',
    '/home/ngoc/.cache/kagglehub/datasets/kmader/skin-cancer-mnist-ham10000/versions/2',
)
# NOTE: read_pickle can execute arbitrary code -- only load trusted files.
df_train = pd.read_pickle(os.path.join(data_dir, 'train_data.pkl'))
df_val = pd.read_pickle(os.path.join(data_dir, 'val_data.pkl'))

# per-channel normalisation statistics (3 channels)
norm_mean = [0.763033, 0.5456458, 0.5700401]
norm_std = [0.14092815, 0.15261315, 0.16997056]
normMean, normStd = norm_mean, norm_std  # original names, kept for other cells
input_size = 64

# training images: resize + random flips / rotation / colour jitter (augmentation)
train_transform = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])
# validation images: deterministic resize + normalise only
val_transform = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# cap worker processes at the CPU count (DataLoader warns when exceeding it)
NUM_WORKERS = min(54, os.cpu_count() or 1)

# Define the training set from df_train with the augmenting train_transform
training_set = HAM10000(df_train, transform=train_transform)
train_loader = DataLoader(training_set, batch_size=128, shuffle=True, num_workers=NUM_WORKERS)
# The validation set must use val_transform: random augmentation at
# evaluation time would perturb the measured accuracy.
validation_set = HAM10000(df_val, transform=val_transform)
val_loader = DataLoader(validation_set, batch_size=128, shuffle=False, num_workers=NUM_WORKERS)

In [None]:
from time import time
from src.train_utils import get_cls_model
from argparse import Namespace


class _Config(Namespace):
    '''Experiment settings; any unset (non-dunder) attribute reads as ``{}``.

    `evaluate` relies on that falsy default for optional settings such as
    `scheduler_kwargs`. Subclassing keeps this behaviour local instead of
    monkeypatching `argparse.Namespace` for the whole process. Dunder names
    still raise AttributeError, so copy/pickle protocol lookups keep working.
    '''

    def __getattr__(self, name):
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        return {}


device = 'cuda' if torch.cuda.is_available() else 'cpu'

# default configurations
# NOTE(review): HAM10000 is usually labelled with 7 lesion types; confirm
# num_classes=10 is intended (e.g. to leave room for extra generated classes).
model = get_cls_model('resnet', num_classes=10, feature_extract=False, use_pretrained=True).to(device)

configs = {
    'name': 'test',                     # experiment name
    'save_dir': './sk-lesion-results',  # base directory to save results
    'model': model,
    'optimizer': torch.optim.SGD,
    'scheduler': None,                  # can change here
    'criterion': torch.nn.CrossEntropyLoss(),
    'lr': 1e-3,
    'bsz': 256,
    'epochs': 10,
    'optimizer_kwargs': {
        'momentum': 0.9,
        'weight_decay': 5e-4,
    },
    'device': device,
    'seed': 42
}
configs = _Config(**configs)
configs.timestamp = str(int(time()))

# append generated data loaders here -- anything appended must be an iterable
# of (data, label) batches
trainloader = [train_loader]
testloader = [val_loader]

# set_seed is run at the beginning of evaluate
evaluate(trainloader, testloader, configs)