In [1]:
import os
import time
import argparse

import torchvision
import torch
import torch.nn as nn

from util import AverageMeter
from encoder import SmallAlexNet
from align_uniform import align_loss, uniform_loss_prelog
from tqdm import tqdm
from collections import defaultdict
import copy

import matplotlib.pyplot as plt


class TwoAugUnsupervisedDatasetLbl(torch.utils.data.Dataset):
    r"""Dataset wrapper yielding two augmented views of each image plus its label.

    Items are ``(transform(image), transform(image), label)``; the transform is
    invoked separately for each view. When ``lblmap`` is given the label is
    looked up through it (``lblmap[label]``), otherwise the raw label is returned.

    Args:
        dataset: indexable source returning ``(image, label)`` pairs.
        transform: callable applied to the image once per view.
        lblmap: optional mapping from source labels to returned labels; a deep
            copy is stored so later edits to the caller's mapping do not leak in.
    """

    def __init__(self, dataset, transform, lblmap=None):
        self.dataset = dataset
        self.transform = transform
        self.lblmap = None if lblmap is None else copy.deepcopy(lblmap)

    def __getitem__(self, index):
        image, raw_label = self.dataset[index]
        if self.lblmap is None:
            label = raw_label
        else:
            label = self.lblmap[raw_label]
        first_view = self.transform(image)
        second_view = self.transform(image)
        return first_view, second_view, label

    def __len__(self):
        return len(self.dataset)

def parse_option(args=None):
    """Build the experiment configuration.

    Args:
        args: optional list of command-line style overrides, e.g.
            ``['--unif_t', '3']``. When None, no arguments are parsed, so that
            running inside a notebook never picks up the kernel's own
            ``sys.argv``.

    Returns:
        argparse.Namespace with the parsed options plus derived fields:
        ``lr`` (0.12 per 256 batch size when not given explicitly), ``gpus``
        converted to ``torch.device('cuda', i)`` objects, and ``save_folder``
        (created on disk if missing).
    """
    parser = argparse.ArgumentParser('STL-10 Representation Learning with Alignment and Uniformity Losses')

    parser.add_argument('--align_w', type=float, default=1, help='Alignment loss weight')
    parser.add_argument('--unif_w', type=float, default=1, help='Uniformity loss weight')
    parser.add_argument('--align_alpha', type=float, default=2, help='alpha in alignment loss')
    parser.add_argument('--unif_t', type=float, default=2, help='t in uniformity loss')

    parser.add_argument('--batch_size', type=int, default=256, help='Batch size')
    parser.add_argument('--epochs', type=int, default=200, help='Number of training epochs')
    parser.add_argument('--iter', type=int, default=0,
                        help='Run/iteration index (appended to the save folder name)')
    parser.add_argument('--lr', type=float, default=None,
                        help='Learning rate. Default is linear scaling 0.12 per 256 batch size')
    parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='Learning rate decay rate')
    parser.add_argument('--lr_decay_epochs', default=[155, 170, 185], nargs='*', type=int,
                        help='When to decay learning rate')
    parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum')
    parser.add_argument('--weight_decay', type=float, default=1e-4, help='L2 weight decay')
    parser.add_argument('--feat_dim', type=int, default=128, help='Feature dimensionality')

    parser.add_argument('--num_workers', type=int, default=4, help='Number of data loader workers to use')
    parser.add_argument('--log_interval', type=int, default=40, help='Number of iterations between logs')
    parser.add_argument('--gpus', default=[0], nargs='*', type=int,
                        help='List of GPU indices to use, e.g., --gpus 0 1 2 3')

    parser.add_argument('--data_folder', type=str, default='./data', help='Path to data')
    parser.add_argument('--result_folder', type=str, default='./results', help='Base directory to save model')

    # Always parse an explicit list: inside Jupyter, sys.argv holds the kernel's
    # own arguments, which argparse would otherwise reject.
    opt = parser.parse_args([] if args is None else list(args))

    if opt.lr is None:
        opt.lr = 0.12 * (opt.batch_size / 256)

    opt.gpus = [torch.device('cuda', idx) for idx in opt.gpus]

    opt.save_folder = os.path.join(
        opt.result_folder,
        f"base_200_sideinformation_align{opt.align_w:g}alpha{opt.align_alpha:g}_unif{opt.unif_w:g}t{opt.unif_t:g}_iter{opt.iter}"
    )
    os.makedirs(opt.save_folder, exist_ok=True)

    return opt


opt = parse_option()

In [2]:
# Pretraining augmentation pipeline: random crop, flip, colour jitter, random
# grayscale, then per-channel normalisation.
# NOTE(review): get_data_loader builds its own identical pipeline, so this
# module-level `transform` appears unused in the visible cells.
_stl10_channel_mean = (0.44087801806139126, 0.42790631331699347, 0.3867879370752931)
_stl10_channel_std = (0.26826768628079806, 0.2610450402318512, 0.26866836876860795)
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomResizedCrop(64, scale=(0.08, 1)),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
        torchvision.transforms.RandomGrayscale(p=0.2),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(_stl10_channel_mean, _stl10_channel_std),
    ]
)

# Label remapping used as side information: each label in `labels_2_keep`
# gets its own new index (in original order); every other label is merged into
# one extra index equal to the number of kept labels.
old_lbls = list(range(10))
labels_2_keep = [0, 1]

kept_lbls = [lbl for lbl in old_lbls if lbl in labels_2_keep]
count = len(kept_lbls)  # new index shared by all labels not in labels_2_keep
old2new = {old: new for new, old in enumerate(kept_lbls)}
old2new.update((old, count) for old in old_lbls if old not in labels_2_keep)

# All new label indices, including the merged one.
new_lbls = list(range(count + 1))

def get_data_loader(opt):
    """Return the shuffled pretraining loader over the STL-10 'train' split.

    Each item is ``(view_1, view_2, label)``: two independent augmentations of
    an image plus its label mapped through the module-level ``old2new`` table.
    Reads ``opt.data_folder``, ``opt.batch_size`` and ``opt.num_workers``.
    """
    stl10_mean = (0.44087801806139126, 0.42790631331699347, 0.3867879370752931)
    stl10_std = (0.26826768628079806, 0.2610450402318512, 0.26866836876860795)
    augment = torchvision.transforms.Compose([
        torchvision.transforms.RandomResizedCrop(64, scale=(0.08, 1)),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
        torchvision.transforms.RandomGrayscale(p=0.2),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(stl10_mean, stl10_std),
    ])
    stl10_train = torchvision.datasets.STL10(opt.data_folder, 'train', download=True)
    paired_views = TwoAugUnsupervisedDatasetLbl(stl10_train, transform=augment, lblmap=old2new)

    return torch.utils.data.DataLoader(
        paired_views,
        batch_size=opt.batch_size,
        num_workers=opt.num_workers,
        shuffle=True,
        pin_memory=True,
    )


print(f'Optimize: {opt.align_w:g} * loss_align(alpha={opt.align_alpha:g}) + {opt.unif_w:g} * loss_uniform(t={opt.unif_t:g})')

# Pretraining: alignment between the two views of each image, plus a uniformity
# term computed separately inside each mapped-label group (only pairs from the
# same group contribute to it).
device = opt.gpus[0]
torch.cuda.set_device(device)
# NOTE(review): PyTorch's reproducibility notes recommend benchmark=False when
# determinism is required; benchmark=True is kept here as in the original.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True

encoder = SmallAlexNet(feat_dim=opt.feat_dim).to(device)

optim = torch.optim.SGD(encoder.parameters(), lr=opt.lr,
                        momentum=opt.momentum, weight_decay=opt.weight_decay)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optim, gamma=opt.lr_decay_rate,
                                                 milestones=opt.lr_decay_epochs)

loader = get_data_loader(opt)
align_meter = AverageMeter('align_loss')
unif_meter = AverageMeter('uniform_loss')
loss_meter = AverageMeter('total_loss')
it_time_meter = AverageMeter('iter_time')

for epoch in range(opt.epochs):
    for meter in (align_meter, unif_meter, loss_meter, it_time_meter):
        meter.reset()
    t0 = time.time()
    for ii, (im_x, im_y, lbl) in enumerate(loader):
        optim.zero_grad()
        # Encode both views in a single forward pass, then split back into (x, y).
        x, y = encoder(torch.cat([im_x.to(device), im_y.to(device)])).chunk(2)

        align_loss_val = align_loss(x, y, alpha=opt.align_alpha)

        # Group embeddings by mapped label; both views carry their image's label.
        z = torch.cat([x, y])
        lbl_z = torch.cat([lbl, lbl]).to(device)
        # Pass the configured t: previously opt.unif_t was parsed, printed and used
        # in the save-folder name but never reached the loss.
        # NOTE(review): assumes uniform_loss_prelog accepts `t` (like
        # align_uniform's uniform_loss(x, t=2)) — confirm its signature.
        unif_losses = torch.cat([
            uniform_loss_prelog(z[lbl_z == new_lbl], t=opt.unif_t) for new_lbl in new_lbls
        ])
        # Mean over the concatenated per-group terms, then a single log.
        unif_loss_val = torch.log(torch.mean(unif_losses))

        loss = align_loss_val * opt.align_w + unif_loss_val * opt.unif_w
        align_meter.update(align_loss_val, x.shape[0])
        unif_meter.update(unif_loss_val)
        loss_meter.update(loss, x.shape[0])
        loss.backward()
        optim.step()
        it_time_meter.update(time.time() - t0)
        if ii % opt.log_interval == 0:
            print(f"Epoch {epoch}/{opt.epochs}\tIt {ii}/{len(loader)}\t" +
                  f"{align_meter}\t{unif_meter}\t{loss_meter}\t{it_time_meter}")
        t0 = time.time()
    scheduler.step()

ckpt_file = os.path.join(opt.save_folder, 'encoder.pth')
torch.save(encoder.state_dict(), ckpt_file)
print(f'Saved to {ckpt_file}')

Optimize: 1 * loss_align(alpha=2) + 1 * loss_uniform(t=2)
Files already downloaded and verified
Epoch 0/200	It 0/20	align_loss 1.246597 (1.246597)	uniform_loss -2.628047 (-2.628047)	total_loss -1.381450 (-1.381450)	iter_time 8.195710 (8.195710)
Epoch 1/200	It 0/20	align_loss 1.356540 (1.356540)	uniform_loss -3.444576 (-3.444576)	total_loss -2.088037 (-2.088037)	iter_time 0.787861 (0.787861)
Epoch 2/200	It 0/20	align_loss 1.153151 (1.153151)	uniform_loss -3.463474 (-3.463474)	total_loss -2.310323 (-2.310323)	iter_time 0.743879 (0.743879)
Epoch 3/200	It 0/20	align_loss 1.125241 (1.125241)	uniform_loss -3.552011 (-3.552011)	total_loss -2.426770 (-2.426770)	iter_time 0.779960 (0.779960)
Epoch 4/200	It 0/20	align_loss 1.091697 (1.091697)	uniform_loss -3.556884 (-3.556884)	total_loss -2.465187 (-2.465187)	iter_time 0.708362 (0.708362)
Epoch 5/200	It 0/20	align_loss 1.070797 (1.070797)	uniform_loss -3.595987 (-3.595987)	total_loss -2.525190 (-2.525190)	iter_time 0.921899 (0.921899)
Epoch 6/20

In [3]:
"""
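    Here we do the linear evaluation: the mapped labels (`old2new`) are concatenated to the encoder features as a one-hot vector, and the linear classifier predicts the original STL-10 labels.
"""

'\n    Here we do the linear evaluation: the mapped labels (`old2new`) are concatenated to the encoder features as a one-hot vector, and the linear classifier predicts the original STL-10 labels.\n'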

In [4]:
class DatasetModifiedLbl(torch.utils.data.Dataset):
    r"""Dataset wrapper yielding one transformed image and its (optionally remapped) label.

    Items are ``(transform(image), label)``. When ``lblmap`` is given the label
    is looked up through it (``lblmap[label]``), otherwise the raw label is
    returned.

    Args:
        dataset: indexable source returning ``(image, label)`` pairs.
        transform: callable applied to the image.
        lblmap: optional mapping from source labels to returned labels; a deep
            copy is stored.
    """

    def __init__(self, dataset, transform, lblmap=None):
        self.dataset = dataset
        self.transform = transform
        self.lblmap = None if lblmap is None else copy.deepcopy(lblmap)

    def __getitem__(self, index):
        image, raw_label = self.dataset[index]
        label = raw_label if self.lblmap is None else self.lblmap[raw_label]
        view = self.transform(image)
        return view, label

    def __len__(self):
        return len(self.dataset)

In [5]:
class DatasetModifiedLblandLbl(torch.utils.data.Dataset):
    r"""Dataset wrapper yielding a transformed image, its remapped label and its raw label.

    Items are ``(transform(image), lblmap[label], label)``; the remapped label
    serves as side information while the raw label remains available as the
    prediction target.

    Args:
        dataset: indexable source returning ``(image, label)`` pairs.
        transform: callable applied to the image.
        lblmap: mapping from source labels to remapped labels (required); a deep
            copy is stored.
    """

    def __init__(self, dataset, transform, lblmap):
        self.dataset = dataset
        self.transform = transform
        self.lblmap = copy.deepcopy(lblmap)

    def __getitem__(self, index):
        image, raw_label = self.dataset[index]
        view = self.transform(image)
        mapped_label = self.lblmap[raw_label]
        return view, mapped_label, raw_label

    def __len__(self):
        return len(self.dataset)

In [6]:
def get_data_loaders(opt, lblmap=None):
    """Build the linear-evaluation loaders over STL-10 'train' and 'test'.

    Each item is ``(image, mapped_label, original_label)`` with
    ``mapped_label = lblmap[original_label]``.

    Args:
        opt: options namespace; reads ``data_folder``, ``batch_size`` and
            ``num_workers``.
        lblmap: label mapping passed to ``DatasetModifiedLblandLbl``; defaults to
            the module-level ``old2new``.

    Returns:
        ``(train_loader, val_loader)``; only the training loader is shuffled.
    """
    if lblmap is None:
        lblmap = old2new

    normalize = torchvision.transforms.Normalize(
        (0.44087801806139126, 0.42790631331699347, 0.3867879370752931),
        (0.26826768628079806, 0.2610450402318512, 0.26866836876860795),
    )
    # The horizontal flip is applied on the PIL image, before normalisation; a
    # flip commutes with per-channel normalisation, so the result is unchanged.
    train_transform = torchvision.transforms.Compose([
        torchvision.transforms.RandomResizedCrop(64, scale=(0.08, 1)),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        normalize,
    ])
    val_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(70),
        torchvision.transforms.CenterCrop(64),
        torchvision.transforms.ToTensor(),
        normalize,
    ])
    # The wrapper requires both `transform` and `lblmap` and applies the
    # transform itself. The raw STL10 datasets are therefore built without one,
    # so images are not transformed twice.
    train_dataset = DatasetModifiedLblandLbl(
        torchvision.datasets.STL10(opt.data_folder, 'train', download=True),
        transform=train_transform,
        lblmap=lblmap,
    )
    val_dataset = DatasetModifiedLblandLbl(
        torchvision.datasets.STL10(opt.data_folder, 'test'),
        transform=val_transform,
        lblmap=lblmap,
    )
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.batch_size,
                                               num_workers=opt.num_workers, shuffle=True, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=opt.batch_size,
                                             num_workers=opt.num_workers, pin_memory=True)
    return train_loader, val_loader


def validate(opt, encoder, classifier, val_loader, num_side_lbls=None):
    """Return the top-1 accuracy of ``classifier`` over ``val_loader``.

    Each batch is ``(images, labels_mod, labels_act)``. Encoder features from
    layer ``opt.layer_index`` are flattened and concatenated with a one-hot
    encoding of ``labels_mod``. The classifier's argmax is then compared with
    ``labels_act``.

    Args:
        opt: options namespace; reads ``opt.gpus[0]`` (device) and ``opt.layer_index``.
        encoder: callable ``encoder(images, layer_index=...)`` returning features.
        classifier: module mapping the concatenated vector to class logits.
        val_loader: iterable of batches exposing ``.dataset`` (its length is the
            accuracy denominator).
        num_side_lbls: width of the one-hot side-information vector. When None it
            is inferred as ``classifier.in_features - feature_dim``, which requires
            an ``nn.Linear``-like classifier.

    Returns:
        float: fraction of correctly classified samples.
    """
    device = opt.gpus[0]
    correct = 0
    with torch.no_grad():
        for images, labels_mod, labels_act in val_loader:
            feats = encoder(images.to(device), layer_index=opt.layer_index).flatten(1)
            width = num_side_lbls
            if width is None:
                width = classifier.in_features - feats.shape[1]
            # Fix the one-hot width explicitly. Without num_classes, one_hot sizes
            # it from the batch maximum, so batches would disagree with the
            # classifier's input size.
            side = torch.nn.functional.one_hot(labels_mod.to(device), num_classes=width).to(feats.dtype)
            pred = classifier(torch.cat([feats, side], dim=1)).argmax(dim=1)
            correct += (pred.cpu() == labels_act).sum().item()
    return correct / len(val_loader.dataset)

In [7]:
setattr(opt, 'gpu', opt.gpus[0])  # alias for the first configured device

In [None]:
# Linear evaluation: load the pretrained encoder, keep it frozen, and train a
# linear classifier on [encoder features, one-hot mapped label] to predict the
# original STL-10 label.
opt = parse_option()
opt.gpu = opt.gpus[0]

# parse_option() defines no linear-eval options; fill them in unless set already.
if getattr(opt, 'encoder_checkpoint', None) is None:
    # Path the pretraining cell saves its encoder to.
    opt.encoder_checkpoint = os.path.join(opt.save_folder, 'encoder.pth')
if getattr(opt, 'layer_index', None) is None:
    # NOTE(review): assumed default; confirm which SmallAlexNet layer to evaluate.
    opt.layer_index = -2

torch.cuda.set_device(opt.gpu)  # expects a single device, not the whole list
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True

encoder = SmallAlexNet(feat_dim=opt.feat_dim).to(opt.gpu)
train_loader, val_loader = get_data_loaders(opt)

# Load the checkpoint once. A state dict goes through load_state_dict, which
# raises TypeError for a non-mapping object. Otherwise the loaded object is
# indexed and its last element is used as the encoder, unwrapped via `.module`
# when present (e.g. a DataParallel wrapper).
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
checkpoint = torch.load(opt.encoder_checkpoint, map_location=opt.gpu)
try:
    encoder.load_state_dict(checkpoint)
except TypeError:
    saved = checkpoint[-1].to(opt.gpu)
    encoder = getattr(saved, 'module', saved)
encoder.eval()  # after loading, since the fallback may replace the encoder object
print(f'Loaded checkpoint from {opt.encoder_checkpoint}')

# Probe the feature size with the encoder actually being evaluated.
with torch.no_grad():
    sample, _, _ = train_loader.dataset[0]  # items are (image, mapped_label, label)
    eval_numel = encoder(sample.unsqueeze(0).to(opt.gpu), layer_index=opt.layer_index).numel()
print(f'Feature dimension: {eval_numel}')

num_side_lbls = len(new_lbls)  # width of the one-hot side-information vector
classifier = nn.Linear(eval_numel + num_side_lbls, len(old_lbls)).to(opt.gpu)

# NOTE(review): opt.lr / epochs / lr_decay_epochs are the pretraining defaults
# from parse_option (SGD rate 0.12 per 256 batch); this is likely too high for
# Adam on a linear probe — confirm.
optim = torch.optim.Adam(classifier.parameters(), lr=opt.lr, betas=(0.5, 0.999))
scheduler = torch.optim.lr_scheduler.MultiStepLR(optim, gamma=opt.lr_decay_rate,
                                                 milestones=opt.lr_decay_epochs)

loss_meter = AverageMeter('loss')
it_time_meter = AverageMeter('iter_time')
for epoch in range(opt.epochs):
    loss_meter.reset()
    it_time_meter.reset()
    t0 = time.time()
    for ii, (images, labels_mod, labels_act) in enumerate(train_loader):
        optim.zero_grad()
        with torch.no_grad():
            feats = encoder(images.to(opt.gpu), layer_index=opt.layer_index).flatten(1)
        # Fixed one-hot width (num_classes) so every batch matches the classifier's
        # input size; cast to the feature dtype before concatenation.
        side = nn.functional.one_hot(labels_mod.to(opt.gpu), num_classes=num_side_lbls).to(feats.dtype)
        logits = classifier(torch.cat([feats, side], dim=1))
        loss = nn.functional.cross_entropy(logits, labels_act.to(opt.gpu))
        loss_meter.update(loss, images.shape[0])
        loss.backward()
        optim.step()
        it_time_meter.update(time.time() - t0)
        if ii % opt.log_interval == 0:
            print(f"Epoch {epoch}/{opt.epochs}\tIt {ii}/{len(train_loader)}\t{loss_meter}\t{it_time_meter}")
        t0 = time.time()
    scheduler.step()
    val_acc = validate(opt, encoder, classifier, val_loader)
    print(f"Epoch {epoch}/{opt.epochs}\tval_acc {val_acc*100:.4g}%")