# Noisy Student with CIFAR-10

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from datasets.cifar_datasets import LabeledCifar10, UnlabeledCifar10, PseudoLabeledCifar10
from randaugment import Rand_Augment

from train import train_model
from models import convnet5, convnet6, resnet6, resnet9, make_models
from criterion import cross_entropy_with_soft_target
import os

In [None]:
# Create the directories that later cells save model checkpoints into.
# `exist_ok=True` makes this cell safe to re-run, and `os.makedirs` creates
# the parent 'checkpoint' directory automatically when it is missing.
for checkpoint_dir in (
    'checkpoint',
    'checkpoint/teacher-student-hard',
    'checkpoint/teacher-student-soft',
):
    os.makedirs(checkpoint_dir, exist_ok=True)

## Hyperparameters

In [None]:
# Paths to the CIFAR-10 data splits read by the dataset classes.
test_file = 'data/cifar10/test.pk'
train_labeled_file = 'data/cifar10/labeled.pk'
train_unlabeled_file = 'data/cifar10/unlabeled.pk'

# Optimisation settings passed to the SGD optimizer, the data loaders and
# the training routine.
learning_rate = 0.001
weight_decay = 0.001
batch_size = 100
max_epoch = 1000

# Device string handed to the training routine and the pseudo-label dataset.
device = 'cuda:0'

In [None]:
# Default training loss (cross-entropy) used by the experiments below.
criterion = nn.CrossEntropyLoss()

# Names used to label logs and checkpoints; model_names[i] is paired with
# models[i] in later cells.
# NOTE(review): presumably listed in the order make_models() returns the
# networks -- confirm against models.py.
model_names = [
    'convnet5',
    'convnet6',
    'resnet6',
    'resnet9',
]

In [None]:
# Build the collection of networks to experiment with. Later cells iterate
# over it and label models[i] with model_names[i].
models = make_models()

In [None]:
# Print the total number of parameter elements of each model next to its name.
for idx, net in enumerate(models):
    n_params = sum(p.numel() for p in net.parameters())
    print('{}: {}'.format(model_names[idx], n_params))

## Search for RandAugment parameters

In [None]:
# Candidate Rand_Augment settings compared by the search cell below: the
# first argument is swept over 5..8 while the second is fixed at 10.
# Alternative grids kept from commented-out lines, as argument pairs:
#   (3, 4), (3, 5), (5, 7), (6, 9)
#   (2, 9), (2, 10), (3, 9), (3, 10)
#   (3, 10), (4, 10), (5, 10), (6, 10)
augmentations = [Rand_Augment(first_arg, 10) for first_arg in (5, 6, 7, 8)]

In [None]:
# Train one freshly initialised network per candidate Rand_Augment setting on
# the labeled split; the test loader is passed to train_model alongside it.
# NOTE(review): this cell previously instantiated `resnet25`, a name that is
# neither imported nor defined in this notebook (NameError at run time).
# `resnet9` is imported from `models`, so it is used here. If `models` does
# provide a resnet25, import it and switch back.
test_dataloader = DataLoader(LabeledCifar10(test_file), batch_size=batch_size, num_workers=2)
for aug in augmentations:
    labeled_dataloader = DataLoader(
        LabeledCifar10(train_labeled_file, transforms=aug),
        batch_size=batch_size,
        shuffle=True,
        num_workers=3,
    )
    model = resnet9()
    optimizer = optim.SGD(model.parameters(), learning_rate, momentum=0.9, weight_decay=weight_decay)
    train_model(model, labeled_dataloader, test_dataloader, criterion, optimizer, device, max_epoch)

## Baseline

In [None]:
# Baseline: train every network from make_models() with data_count=5000,
# logging to TensorBoard and saving the final weights per model.
# NOTE(review): the training samples are read from `train_unlabeled_file`
# although the loader is named "labeled" -- presumably that file also carries
# ground-truth labels; confirm against LabeledCifar10.
test_dataloader = DataLoader(
    LabeledCifar10(test_file),
    batch_size=batch_size,
    num_workers=2,
)
aug = Rand_Augment(5, 10)
models = make_models()
for i, model in enumerate(models):
    # A new dataset/loader is built for every model, as in the original cell.
    train_set = LabeledCifar10(train_unlabeled_file, transforms=aug, data_count=5000)
    labeled_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
    writer = SummaryWriter('logs/baseline/{}-{}'.format(model_names[i], 5000))
    optimizer = optim.SGD(model.parameters(), learning_rate, momentum=0.9, weight_decay=weight_decay)
    train_model(model, labeled_dataloader, test_dataloader, criterion, optimizer, device, max_epoch, writer)
    weights = {'model_state_dict': model.state_dict()}
    torch.save(weights, 'checkpoint/{}-{}'.format(model_names[i], 5000))

## Noisy-Student

#### Hard pseudolabels

In [None]:
# Noisy-Student chain with hard pseudo-labels (soft=False).
# models[0] is trained on the labeled split only; every later models[i] is
# trained on a PseudoLabeledCifar10 dataset that is given models[i - 1].
# Logs and checkpoints go under 'teacher-student-hard' so this run is kept
# separate from the soft pseudo-label run, which uses 'teacher-student-soft'.
criterion = nn.CrossEntropyLoss()
augs = [Rand_Augment(5, 10), Rand_Augment(5, 10), Rand_Augment(5, 10), Rand_Augment(5, 10)]
max_epoches = [1000, 1000, 1000, 1000]
models = make_models()
test_dataloader = DataLoader(LabeledCifar10(test_file), batch_size=batch_size, num_workers=2)
labeled_dataloader = DataLoader(LabeledCifar10(train_labeled_file, transforms=augs[0]), batch_size=batch_size, shuffle=True, num_workers=2)

# torch.save needs the target directory to exist; creating it here keeps the
# cell runnable on its own and safe to re-run.
os.makedirs('checkpoint/teacher-student-hard', exist_ok=True)

for i, student in enumerate(models):
    writer = SummaryWriter('logs/teacher-student-hard/{}'.format(model_names[i]))
    if i == 0:
        # First model in the chain: labeled data only.
        train_dataloader = labeled_dataloader
    else:
        # Later models: dataset built from both splits and the previous model.
        pseudo_labeled_dataloader = DataLoader(PseudoLabeledCifar10(train_labeled_file, train_unlabeled_file, models[i - 1], device, augs[i], soft=False), batch_size, shuffle=True, num_workers=2)
        train_dataloader = pseudo_labeled_dataloader

    optimizer = optim.SGD(student.parameters(), learning_rate, momentum=0.9, weight_decay=weight_decay)
    train_model(student, train_dataloader, test_dataloader, criterion, optimizer, device, max_epoches[i], writer)

    torch.save({'model_state_dict': student.state_dict()}, 'checkpoint/teacher-student-hard/{}'.format(model_names[i]))

#### Soft pseudolabels

In [None]:
# Noisy-Student chain with soft pseudo-labels.
# models[0] is trained on the labeled split with nn.CrossEntropyLoss; every
# later models[i] is trained on a PseudoLabeledCifar10 dataset that is given
# models[i - 1], using cross_entropy_with_soft_target as the loss.
augs = [Rand_Augment(5, 10) for _ in range(4)]
max_epoches = [1000] * 4
models = make_models()
test_dataloader = DataLoader(
    LabeledCifar10(test_file),
    batch_size=batch_size,
    num_workers=2,
)
labeled_dataloader = DataLoader(
    LabeledCifar10(train_labeled_file, transforms=augs[0]),
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)
for i, student in enumerate(models):
    writer = SummaryWriter('logs/teacher-student-soft/{}'.format(model_names[i]))
    if i > 0:
        pseudo_set = PseudoLabeledCifar10(train_labeled_file, train_unlabeled_file, models[i - 1], device, transforms=augs[i])
        pseudo_labeled_dataloader = DataLoader(pseudo_set, batch_size, shuffle=True, num_workers=2)
        train_dataloader, loss_fn = pseudo_labeled_dataloader, cross_entropy_with_soft_target
    else:
        train_dataloader, loss_fn = labeled_dataloader, nn.CrossEntropyLoss()

    optimizer = optim.SGD(student.parameters(), learning_rate, momentum=0.9, weight_decay=weight_decay)
    train_model(student, train_dataloader, test_dataloader, loss_fn, optimizer, device, max_epoches[i], writer)

    weights = {'model_state_dict': student.state_dict()}
    torch.save(weights, 'checkpoint/teacher-student-soft/{}'.format(model_names[i]))

In [None]:
# Data-size sweep: train a fresh resnet9 for each data_count value, logging
# each run under 'logs/baseline' and saving its final weights.
data_counts = list(range(5000, 12000, 1000)) + [20000, 30000, 40000]
aug = Rand_Augment(5, 10)
test_dataloader = DataLoader(
    LabeledCifar10(test_file),
    batch_size=batch_size,
    num_workers=2,
)
for data_count in data_counts:
    model = resnet9()
    train_set = LabeledCifar10(train_unlabeled_file, transforms=aug, data_count=data_count)
    labeled_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
    writer = SummaryWriter('logs/baseline/{}-{}'.format('resnet9', data_count))
    optimizer = optim.SGD(model.parameters(), learning_rate, momentum=0.9, weight_decay=weight_decay)
    train_model(model, labeled_dataloader, test_dataloader, criterion, optimizer, device, max_epoch, writer)
    weights = {'model_state_dict': model.state_dict()}
    torch.save(weights, 'checkpoint/{}-{}'.format('resnet9', data_count))