In [27]:
import sys
sys.path.insert(0, '..')

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import pylab as plt

from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, RandomHorizontalFlip, ToTensor, RandomCrop, Normalize
from tqdm.auto import tqdm
# from models.lenet import LeNet, evaluate
from models.resnet import ResNet18, evaluate, train_one_epoch
from utils import seed_everything

In [28]:
transform = Compose([
    RandomCrop(32, padding=4), 
    RandomHorizontalFlip(), 
    ToTensor(), 
    Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.262)),
])

train_ds = CIFAR10('/tmp', train=True, download=True, transform=transform)
test_ds = CIFAR10('/tmp', train=False, download=True, transform=transform)

Files already downloaded and verified
Files already downloaded and verified


In [60]:
from sklearn.model_selection import train_test_split

targets = train_ds.targets
labeled_indices, unlabeled_indices = train_test_split(np.arange(len(train_ds)), test_size=.9992, stratify=targets, shuffle=True, random_state=42)

test_loader = DataLoader(test_ds, batch_size=32)
labeled_loader = DataLoader(train_ds, batch_size=32, sampler=labeled_indices, drop_last=True)
unlabeled_loader = DataLoader(train_ds, batch_size=32, sampler=unlabeled_indices, drop_last=True)

targets = torch.cat([y for _,y in labeled_loader]).numpy()
np.unique(targets, return_counts=True)

(array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), array([3, 3, 4, 3, 2, 3, 4, 4, 3, 3]))

# Supervised

In [61]:
# Supervised
seed_everything(0)
model = ResNet18(10)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=.9, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss()

n_epochs = 200
for i in range(n_epochs):
    train_one_epoch(model, labeled_loader, criterion, optimizer, epoch=i, device='cuda')
evaluate(model, test_loader, {}, criterion, device='cuda')

Epoch [0] [0/1] eta: 0:00:00 lr: 0.01 loss: 2.3496 (2.3496) acc1: 9.3750 (9.3750) time: 0.0274 data: 0.0060 max mem: 632
Epoch [0] Total time: 0:00:00
Epoch [1] [0/1] eta: 0:00:00 lr: 0.01 loss: 2.1804 (2.1804) acc1: 15.6250 (15.6250) time: 0.0277 data: 0.0055 max mem: 632
Epoch [1] Total time: 0:00:00
Epoch [2] [0/1] eta: 0:00:00 lr: 0.01 loss: 2.0329 (2.0329) acc1: 28.1250 (28.1250) time: 0.0284 data: 0.0063 max mem: 632
Epoch [2] Total time: 0:00:00
Epoch [3] [0/1] eta: 0:00:00 lr: 0.01 loss: 1.9223 (1.9223) acc1: 40.6250 (40.6250) time: 0.0277 data: 0.0056 max mem: 632
Epoch [3] Total time: 0:00:00
Epoch [4] [0/1] eta: 0:00:00 lr: 0.01 loss: 1.8174 (1.8174) acc1: 40.6250 (40.6250) time: 0.0286 data: 0.0065 max mem: 632
Epoch [4] Total time: 0:00:00
Epoch [5] [0/1] eta: 0:00:00 lr: 0.01 loss: 1.6492 (1.6492) acc1: 37.5000 (37.5000) time: 0.0275 data: 0.0054 max mem: 632
Epoch [5] Total time: 0:00:00
Epoch [6] [0/1] eta: 0:00:00 lr: 0.01 loss: 1.5095 (1.5095) acc1: 40.6250 (40.6250) 

{'test_acc1': 15.949999809265137,
 'test_prec': 0.1687579575989091,
 'test_loss': 6.506824493408203,
 'test_nll': 6.506824493408203,
 'test_tce': 0.6696732044219971,
 'test_mce': 0.23505638539791107}

# Semi-Supervised 

In [62]:
def freeze_bn(model):
    for name ,child in (model.named_children()):
        if name.find('BatchNorm') != -1:
            for param in child.parameters():
                param.requires_grad = True
        else:
            for param in child.parameters():
                param.requires_grad = False 

def unfreeze_bn(model):
    for name ,child in (model.named_children()):
        if name.find('BatchNorm') != -1:
            for param in child.parameters():
                param.requires_grad = True
        else:
            for param in child.parameters():
                param.requires_grad = True

In [64]:
device = 'cuda'
seed_everything(0)
def unsup_warmup(it, n_iter, unsup_warmup_pos=.4):
    return np.clip(it / (unsup_warmup_pos * n_iter), a_min=0.0, a_max=1.0)

model = ResNet18(10)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=.9, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss()

n_iter = 1000
lambda_u = 1
labeled_iter = iter(labeled_loader)
unlabeled_iter = iter(unlabeled_loader)
model.train()
model.to(device)

for i_iter in tqdm(range(n_iter)):
    try:
        inputs_l, targets_l = next(labeled_iter)
    except StopIteration:
        labeled_iter = iter(labeled_loader)
        inputs_l, targets_l = next(labeled_iter)
    try:
        inputs_u, _ = next(unlabeled_iter)
    except StopIteration:
        unlabeled_iter = iter(unlabeled_loader)
        inputs_u, _ = next(unlabeled_iter)
    inputs_l, targets_l, inputs_u = inputs_l.to(device), targets_l.to(device), inputs_u.to(device)

    # Labeled stuff
    logits_l = model(inputs_l)
    loss_l = F.cross_entropy(logits_l, targets_l)

    # Pseudo labels
    freeze_bn(model)
    logits_u = model(inputs_u)
    unfreeze_bn(model)
    probas_u = logits_u.detach().softmax(-1)
    max_probas, pseudo_labels = probas_u.max(-1)
    mask = (max_probas > .95).float()
    loss_u = F.cross_entropy(logits_u, pseudo_labels, reduction='none') * mask
    loss_u = loss_u.mean()

    loss = loss_l + loss_u * unsup_warmup(i_iter, n_iter) * lambda_u

    # Consistency loss
    # logits_u1 = model(augmentation_fn(inputs_u))
    # logits_u2 = model(augmentation_fn(inputs_u))
    # loss_mse = F.mse_loss(logits_u1.softmax(-1), logits_u2.softmax(-1))
    # loss = loss_ce + loss_mse * unsup_warmup(i_iter, n_iter) * lambda_u

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

evaluate(model, test_loader, {}, criterion, device='cuda')

100%|██████████| 1000/1000 [00:29<00:00, 33.69it/s]


{'test_acc1': 15.380000114440918,
 'test_prec': 0.15690584267711474,
 'test_loss': 6.638526439666748,
 'test_nll': 6.638526439666748,
 'test_tce': 0.6801223754882812,
 'test_mce': 0.23900383710861206}