In [1]:
import numpy as np
import dataset
from dataset.datasets import C100Dataset
import csv
import matplotlib.pyplot as plt
import tqdm
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms

## 1. Prepare Dataset

In [2]:
# Load the training and test splits from their CSV files.
# NOTE(review): "nl" presumably stands for "noisy labels" — confirm against the dataset package.
dataset_nl = C100Dataset('./dataset/data/cifar100_nl.csv')
# getDataset() yields four items; only the first two (images, labels) are kept here.
[data_nl_tr_x, data_nl_tr_y, _, _] = dataset_nl.getDataset()
dataset_nl_test = C100Dataset('./dataset/data/cifar100_nl_test.csv')
[data_nl_ts_x, data_nl_ts_y, _, _] = dataset_nl_test.getDataset()

In [4]:
# Sanity check: display the shape of the training image array.
data_nl_tr_x.shape

(49999, 3, 32, 32)

In [5]:
class GaussianBlur(object):
    """Random Gaussian blur for a single 3-channel image, executed on CPU.

    Every call draws a fresh sigma from [0.1, 2.0), builds a normalised 1-D
    Gaussian kernel and applies it with two depthwise convolutions (kernel
    shapes ``(k, 1)`` and ``(1, k)``) after reflection padding, so the spatial
    size of the image is unchanged.
    """

    def __init__(self, kernel_size):
        # Round the requested size to an odd width so the kernel has a centre tap.
        half_width = kernel_size // 2
        odd_size = 2 * half_width + 1

        def _depthwise(shape):
            # groups=3: each colour channel is filtered independently.
            return nn.Conv2d(3, 3, kernel_size=shape,
                             stride=1, padding=0, bias=False, groups=3)

        self.blur_h = _depthwise((odd_size, 1))
        self.blur_v = _depthwise((1, odd_size))
        self.k = odd_size
        self.r = half_width

        # Reflection padding compensates for the unpadded convolutions.
        self.blur = nn.Sequential(nn.ReflectionPad2d(half_width), self.blur_h, self.blur_v)

        self.pil_to_tensor = transforms.ToTensor()
        self.tensor_to_pil = transforms.ToPILImage()

    def __call__(self, img):
        """Blur ``img`` with a randomly sized Gaussian and return the converted result."""
        batch = self.pil_to_tensor(img).unsqueeze(0)

        # Sample the blur strength, then build a unit-sum 1-D Gaussian.
        sigma = np.random.uniform(0.1, 2.0)
        offsets = np.arange(-self.r, self.r + 1)
        taps = np.exp(-np.power(offsets, 2) / (2 * sigma * sigma))
        taps = taps / taps.sum()
        kernel = torch.from_numpy(taps).view(1, -1).repeat(3, 1)

        # Overwrite the convolution weights in place with this call's kernel.
        self.blur_h.weight.data.copy_(kernel.view(3, 1, self.k, 1))
        self.blur_v.weight.data.copy_(kernel.view(3, 1, 1, self.k))

        with torch.no_grad():
            blurred = self.blur(batch).squeeze()

        return self.tensor_to_pil(blurred)

In [6]:
# Light augmentation pipeline: random 32x32 crop after 4-pixel padding,
# random horizontal flip, tensor conversion and per-channel normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Heavier augmentation pipeline: random resized crop, flip, colour jitter
# (applied with probability 0.8), random grayscale and Gaussian blur.
# NOTE(review): the name differs from `transform_train` above by a single "s",
# which is easy to mix up; neither pipeline is referenced by the cells shown here.
transforms_train = transforms.Compose([
    transforms.RandomResizedCrop(size=32),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    # int(0.1 * 32) == 3, i.e. a 3-tap blur kernel.
    GaussianBlur(kernel_size=int(0.1 * 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Evaluation pipeline: no augmentation, same normalisation constants as above.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

In [12]:
def make_batch_dataset(N, batch_size, x, y, image_shape=(3, 32, 32)):
    """Group the first ``N`` samples into fixed-size batches.

    Only complete batches are produced: the trailing ``N % batch_size``
    samples are dropped.

    Args:
        N: number of leading samples of ``x`` / ``y`` to consider.
        batch_size: number of samples per batch.
        x: indexable collection of images (array-like, first axis = sample).
        y: indexable collection of labels aligned with ``x``.
        image_shape: per-sample shape each image is reshaped to. Defaults to
            ``(3, 32, 32)``, the value that used to be hard-coded.

    Returns:
        ``(images, labels)`` where ``images`` has shape
        ``(N // batch_size, batch_size, *image_shape)`` and ``labels`` has
        shape ``(N // batch_size, batch_size)``.

    Raises:
        IndexError: if ``x`` or ``y`` holds fewer samples than are needed.
        ZeroDivisionError: if ``batch_size`` is 0.
    """
    # A non-positive N simply yields zero batches.
    num_batches = max(N // batch_size, 0)
    used = num_batches * batch_size
    if used > len(x) or used > len(y):
        raise IndexError(
            "make_batch_dataset needs %d samples but got len(x)=%d, len(y)=%d"
            % (used, len(x), len(y))
        )

    # Slice once and reshape instead of appending sample by sample in Python;
    # np.array(...) copies, so the result never aliases the caller's data.
    images = np.array(x[:used]).reshape(num_batches, batch_size, *image_shape)
    labels = np.array(y[:used]).reshape(num_batches, batch_size)
    return images, labels

In [13]:
# Pre-batch the whole training split into an (images, labels) tuple of arrays.
# make_batch_dataset keeps only N // batch_size full batches, so any trailing
# samples that do not fill a batch of 128 are dropped.
train_dataset = make_batch_dataset(N=data_nl_tr_y.shape[0], batch_size=128, x=data_nl_tr_x, y = data_nl_tr_y)

## 2. Model Setting

In [17]:
class ResNet_(nn.Module):
    """torchvision ResNet whose final layer is wrapped in a two-layer MLP head.

    The backbone's original ``fc`` layer (mapping to ``out_dim`` outputs) is
    replaced by ``Linear(d, d) -> ReLU -> original fc`` where ``d`` is the
    backbone's feature width.

    Args:
        base_model: backbone name, one of the keys of ``resnet_dict``
            (``"resnet18"`` or ``"resnet50"``).
        out_dim: number of outputs produced by the final linear layer.

    Raises:
        KeyError: if ``base_model`` is not a supported backbone name.
    """

    def __init__(self, base_model, out_dim):
        super(ResNet_, self).__init__()
        # Store constructors rather than built networks: previously both
        # resnet18 and resnet50 were instantiated on every construction even
        # though only one of them was ever used.
        self.resnet_dict = {"resnet18": models.resnet18,
                            "resnet50": models.resnet50}
        self.out_dim = out_dim

        self.backbone = self.get_basemodel(base_model)
        dim_mlp = self.backbone.fc.in_features

        # add mlp projection head
        self.backbone.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.backbone.fc)

    def get_basemodel(self, model_name):
        """Build and return a freshly initialised backbone called ``model_name``."""
        try:
            build = self.resnet_dict[model_name]
        except KeyError:
            raise KeyError(
                "Unknown base_model %r; expected one of %s"
                % (model_name, sorted(self.resnet_dict))
            ) from None
        return build(pretrained=False, num_classes=self.out_dim)

    def forward(self, x):
        """Run ``x`` through the backbone including the MLP head."""
        return self.backbone(x)

In [18]:
# NOTE(review): out_dim=128 yields 128 outputs while the dataset class is named
# C100 — confirm this is intended before using the outputs as class logits.
model = ResNet_(base_model='resnet18', out_dim=128)

# Prefer Apple's MPS backend (the original hard-coded choice), but fall back to
# CUDA or CPU so the notebook also runs on machines without MPS.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = model.to(device)

In [None]:
def _predict(data_loader, model):
    model.eval()
    
    pred_list = []
    label_list = []
    with torch.no_grad():
        for data in data_loader:
            images, labels = data
            images, labels = images.cuda(), labels.cuda()

            outputs = model(images)

            pred_list.append(outputs)
            label_list.append(labels)
            
    return pred_list, label_list

def _accuracy(output, target, topk=(1,)):
    
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        # import pdb; pdb.set_trace()

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            k_score = (correct_k.mul_(100.0 / batch_size))
            res.append(k_score.item())
        return res
    
def _accuracy_sep(output, target, topk=(1,)):
    """Compute top-k accuracy separately for the lower and upper half of the classes.

    The number of classes is inferred as ``max(target) + 1``. "Head" samples
    are those whose label is below half of that count, "tail" samples the rest.

    Args:
        output: score tensor with one row per sample.
        target: 1-D tensor of ground-truth class indices.
        topk: iterable of k values, forwarded to ``_accuracy``.

    Returns:
        ``(res_head, res_tail)``: the ``_accuracy`` results for each subset.
    """
    num_class = torch.max(target)+1
    mask_head = target < num_class//2
    mask_tail = target >= num_class//2

    # Forward topk: it was previously accepted but silently ignored, so the
    # split results were always top-1 regardless of what the caller asked for.
    res_head = _accuracy(output[mask_head], target[mask_head], topk)
    res_tail = _accuracy(output[mask_tail], target[mask_tail], topk)
    return res_head, res_tail
    
def compute_accuracy(data_loader, model, topk=(1,)):
    """Evaluate ``model`` on ``data_loader`` and report overall and split accuracy.

    Returns:
        ``(res, res_head, res_tail)``: the ``_accuracy`` result over all
        samples, followed by the two results from ``_accuracy_sep``.
    """
    batch_outputs, batch_labels = _predict(data_loader, model)

    # Merge the per-batch tensors into one tensor along the sample axis.
    all_outputs = torch.cat(batch_outputs, dim=0)
    all_labels = torch.cat(batch_labels, dim=0)

    overall = _accuracy(all_outputs, all_labels, topk)
    head, tail = _accuracy_sep(all_outputs, all_labels, topk)
    return overall, head, tail

In [19]:
def train(model, train_dataset, criterion, optimizer, scheduler, EPOCHS):
    """Train ``model`` for ``EPOCHS`` epochs and print the mean loss of each epoch.

    Args:
        model: ``nn.Module`` to optimise; batches are moved to its device.
        train_dataset: ``(images, labels)`` pair of equally long sequences of
            pre-built batches, as returned by ``make_batch_dataset``.
        criterion: loss function called as ``criterion(pred, label_batch)``.
        optimizer: optimizer holding the model's parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        EPOCHS: number of passes over ``train_dataset``.
    """
    # Send batches to wherever the model lives rather than relying on a
    # notebook-global `device` variable.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    for epoch in range(EPOCHS):
        loss_history = []
        model.train()
        for image_batch, label_batch in zip(*train_dataset):
            # float32 matches the default dtype of nn.Module weights (and
            # float64 is not supported on MPS); int64 is the dtype expected
            # for class-index targets by nn.CrossEntropyLoss.
            image_batch = torch.tensor(image_batch, dtype=torch.float32).to(target_device)
            label_batch = torch.tensor(label_batch, dtype=torch.long).to(target_device)

            pred = model(image_batch)
            loss = criterion(pred, label_batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Record the batch loss; previously nothing was appended, so the
            # epoch summary below always printed nan.
            loss_history.append(loss.item())

        print(np.mean(loss_history))
        scheduler.step()
    

# Optimisation hyper-parameters.
LR = 0.1
# NOTE(review): BATCH_SIZE is not referenced by the code shown in this notebook;
# the batching call that builds train_dataset passes the literal 128 instead.
BATCH_SIZE = 128 
MOMENTUM = 0.9
WEIGHT_DECAY = 2e-4
EPOCHS = 200

model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
# train() steps the scheduler once per epoch, so the learning rate is
# multiplied by 0.1 after 150 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=150, gamma=0.1)
    
print('# training')
train(model, train_dataset, criterion, optimizer, scheduler, EPOCHS)

# training


KeyboardInterrupt: 