<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Import" data-toc-modified-id="Import-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Import</a></span></li><li><span><a href="#Utils" data-toc-modified-id="Utils-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Utils</a></span></li><li><span><a href="#Load-and-prepare-the-data" data-toc-modified-id="Load-and-prepare-the-data-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Load and prepare the data</a></span></li><li><span><a href="#Create-model" data-toc-modified-id="Create-model-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Create model</a></span></li><li><span><a href="#Loss-function-&amp;-warmup" data-toc-modified-id="Loss-function-&amp;-warmup-5"><span class="toc-item-num">5&nbsp;&nbsp;</span>Loss function &amp; warmup</a></span></li><li><span><a href="#Training" data-toc-modified-id="Training-6"><span class="toc-item-num">6&nbsp;&nbsp;</span>Training</a></span><ul class="toc-item"><li><span><a href="#Normal-train" data-toc-modified-id="Normal-train-6.1"><span class="toc-item-num">6.1&nbsp;&nbsp;</span>Normal train</a></span></li></ul></li><li><span><a href="#End" data-toc-modified-id="End-7"><span class="toc-item-num">7&nbsp;&nbsp;</span>End</a></span></li></ul></div>

# Import

In [1]:
import os
import sys
import time
import tqdm

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms
from torch.nn.utils import weight_norm
from torch.utils.tensorboard import SummaryWriter



from advertorch.attacks import GradientSignAttack

# Utils

In [2]:
class Metrics:
    """Base class for metrics that keep a running average over calls.

    Attributes:
        value: result of the most recent call (written by subclasses).
        accumulate_value: sum of the per-call values since the last ``reset``.
        count: number of calls since the last ``reset``.
        epsilon: small constant stored for subclasses; not used by this class.
    """

    def __init__(self, epsilon=1e-10):
        self.epsilon = epsilon
        self.value = 0
        self.count = 0
        self.accumulate_value = 0

    def reset(self):
        """Clear the running sum and the call counter (``value`` is kept)."""
        self.count = 0
        self.accumulate_value = 0

    def __call__(self):
        """Register one more call; subclasses invoke this before computing."""
        self.count = self.count + 1

        
class BinaryAccuracy(Metrics):
    """Running element-wise accuracy for binary / multi-label predictions.

    Predictions are thresholded at 0.5 and compared element by element with
    the targets.
    """

    def __init__(self, epsilon=1e-10):
        Metrics.__init__(self, epsilon)

    def __call__(self, y_pred, y_true):
        """Update the metric with one batch.

        Args:
            y_pred: tensor of scores, same shape as ``y_true``.
            y_true: tensor of 0/1 targets, of any shape.

        Returns:
            The mean of the per-batch accuracies seen since the last ``reset``.
        """
        super().__call__()

        with torch.no_grad():
            hard_pred = (y_pred > 0.5).float()
            correct = (hard_pred == y_true).float().sum()
            # Divide by the total number of elements so that targets of any
            # rank work (the 2-D case is unchanged: numel == shape[0] * shape[1]).
            self.value = correct / y_true.numel()

            self.accumulate_value += self.value
            return self.accumulate_value / self.count
        
        
class CategoricalAccuracy(Metrics):
    """Running accuracy between predicted and true class indices."""

    def __init__(self, epsilon=1e-10):
        Metrics.__init__(self, epsilon)

    def __call__(self, y_pred, y_true, maxi: bool = True):
        """Update the metric with one batch.

        Args:
            y_pred: tensor of predicted class indices.
            y_true: tensor of target class indices, comparable with ``y_pred``.
            maxi: accepted for interface compatibility; not used.

        Returns:
            The mean of the per-batch accuracies seen since the last ``reset``.
        """
        super().__call__()

        with torch.no_grad():
            hits = (y_true == y_pred).float()
            self.value = hits.mean()
            self.accumulate_value += self.value
            running_mean = self.accumulate_value / self.count

        return running_mean

        
class Ratio(Metrics):
    """Running percentage of samples whose prediction changes under attack."""

    def __init__(self, epsilon=1e-10):
        Metrics.__init__(self, epsilon)

    def __call__(self, y_pred, y_adv_pred):
        """Update the metric with one batch.

        Args:
            y_pred: iterable of predictions on the clean inputs.
            y_adv_pred: iterable of predictions on the perturbed inputs,
                paired element by element with ``y_pred``.

        Returns:
            The mean, over the calls since the last ``reset``, of the
            percentage of pairs that differ.
        """
        super().__call__()

        flipped = []
        for clean, adversarial in zip(y_pred, y_adv_pred):
            flipped.append(int(clean != adversarial))

        self.value = sum(flipped) / len(flipped) * 100
        self.accumulate_value += self.value

        return self.accumulate_value / self.count

In [3]:
def reset_seed(seed=1234):
    """Seed the NumPy, PyTorch CPU and PyTorch CUDA random generators."""
    for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)


# Seed once at notebook start so the data split and model init are repeatable.
reset_seed()

In [4]:
import datetime
def get_datetime():
    """Return the current local time formatted as ``YYYY-MM-DD_HH:MM:SS``.

    The value is formatted with ``strftime`` instead of slicing ``str(now)``:
    the string form of a datetime omits the fractional part when the
    microsecond field is 0, which made the previous fixed-offset slice drop
    most of the time component in that case.
    """
    return datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")

In [5]:
# Compute device: the first CUDA GPU when one is available, the CPU otherwise.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")

# Turn off the cuDNN auto-tuner and request deterministic kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Load and prepare the data

In [7]:
# Directory the CIFAR-10 archives are downloaded to / read from.
dataset_root = os.path.join("..", "dataset")

# Training images: random translation of up to 1/16 of the image size on each
# axis (no rotation), random horizontal flip, then conversion to a tensor.
transform_train = transforms.Compose([
    transforms.RandomAffine(degrees=0, translate=(1/16, 1/16)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# Validation images are only converted to tensors.
transform_val = transforms.Compose([transforms.ToTensor()])

train_set = torchvision.datasets.CIFAR10(
    root=dataset_root, train=True, download=True, transform=transform_train,
)
val_set = torchvision.datasets.CIFAR10(
    root=dataset_root, train=False, download=True, transform=transform_val,
)

PermissionError: [Errno 13] Permission denied: '../dataset'

In [None]:
# Index lists of the labelled (S) and unlabelled (U) training subsets; they are
# filled once the per-class index has been built.
S_idx, U_idx = [], []

# classes[c] holds the dataset indices of every training image with label c.
classes = [[] for _ in range(10)]

# Read the labels straight from the dataset's `targets` list instead of
# indexing `train_set[i]`, which would decode and augment every image just to
# obtain its label.
for i, label in enumerate(train_set.targets):
    classes[label].append(i)

In [None]:
# Number of labelled samples kept for each class; the rest is treated as
# unlabelled.
nb_labeled_per_class = 400

# Rebuild both lists from scratch so that re-running this cell does not append
# a second copy of the indices.
S_idx, U_idx = [], []
for indexes in classes:
    np.random.shuffle(indexes)
    S_idx += indexes[:nb_labeled_per_class]
    U_idx += indexes[nb_labeled_per_class:]

# Create model

In [None]:
class repro(torch.nn.Module):
    """Convolutional classifier producing 10 logits from 3x32x32 images.

    ``features`` is three groups of weight-normalised convolutions with
    LeakyReLU(0.1) activations; the first two groups end with 2x2 max pooling
    and dropout, the last one with a 6x6 average pooling. ``classifier`` is a
    single weight-normalised linear layer mapping the 128 pooled features to
    10 outputs.

    NOTE(review): the first BatchNorm2d uses the default momentum while all
    others use 0.999, and the third convolution of the first group has no
    BatchNorm at all - confirm whether both are intentional.
    NOTE(review): in PyTorch, ``momentum`` is the weight given to the current
    batch statistics, so 0.999 makes the running statistics follow the latest
    batch almost entirely - confirm this is not a port of a "decay = 0.999"
    convention from another framework.
    """

    def __init__(self):
        """Build the ``features`` and ``classifier`` sub-networks."""
        torch.nn.Module.__init__(self)
        
#         self.gaussian = GaussianNoise(sigma=0.15)
        self.features = torch.nn.Sequential(
            # Group 1: 3 -> 128 channels; max pooling halves 32x32 to 16x16.
            weight_norm(nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1)),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(negative_slope=0.1),
            weight_norm(nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)),
            nn.BatchNorm2d(128, momentum=0.999),
            nn.LeakyReLU(negative_slope=0.1),
            weight_norm(nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)),
            nn.LeakyReLU(negative_slope=0.1),
            nn.MaxPool2d(kernel_size=(2,2), stride=(2,2)),
            nn.Dropout(0.5),
            
            # Group 2: 128 -> 256 channels; max pooling halves 16x16 to 8x8.
            weight_norm(nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)),
            nn.BatchNorm2d(256, momentum=0.999),
            nn.LeakyReLU(negative_slope=0.1),
            weight_norm(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)),
            nn.BatchNorm2d(256, momentum=0.999),
            nn.LeakyReLU(negative_slope=0.1),
            weight_norm(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)),
            nn.BatchNorm2d(256, momentum=0.999),
            nn.LeakyReLU(negative_slope=0.1),
            nn.MaxPool2d(kernel_size=(2,2), stride=(2,2)),
            nn.Dropout(0.5),
            
            # Group 3: the unpadded 3x3 convolution shrinks 8x8 to 6x6, two 1x1
            # convolutions reduce 512 -> 256 -> 128 channels, and the 6x6
            # average pooling collapses the map to 1x1.
            weight_norm(nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=0)),
            nn.BatchNorm2d(512, momentum=0.999),
            nn.LeakyReLU(negative_slope=0.1),
            weight_norm(nn.Conv2d(512, 256, kernel_size=1, stride=1, padding=0)),
            nn.BatchNorm2d(256, momentum=0.999),
            nn.LeakyReLU(negative_slope=0.1),
            weight_norm(nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0)),
            nn.BatchNorm2d(128, momentum=0.999),
            nn.LeakyReLU(negative_slope=0.1),
            nn.AvgPool2d(6, stride=2, padding=0),
        )
        
        self.classifier = nn.Sequential(
            weight_norm(nn.Linear(128, 10))
        )
    
        
    def forward(self, x):
        """Return the class logits for a batch of images.

        Args:
            x: tensor that can be viewed as ``(-1, 3, 32, 32)``.

        Returns:
            Tensor of shape ``(batch, 10)`` holding unnormalised class scores.
        """
        x = x.view(-1, 3, 32, 32)
#         x = self.gaussian(x)
        
        x = self.features(x)
        # Flatten the (batch, 128, 1, 1) feature map into (batch, 128).
        x = x.view(-1, 128)
#         x = nn.functional.avg_pool2d(x, kernel_size=(6,6))
#         x = x.view(-1, 128)
        x = self.classifier(x)
        
        return x

# Loss function & warmup

In [None]:
def Lsup(logit_S1, logit_S2, labels_S1, labels_S2):
    """Supervised loss: sum of the two views' mean cross-entropies.

    Args:
        logit_S1, logit_S2: logits of each view on its own labelled batch.
        labels_S1, labels_S2: class indices matching the respective logits.

    Returns:
        Scalar tensor ``CE(logit_S1, labels_S1) + CE(logit_S2, labels_S2)``.
    """
    first_view_loss = F.cross_entropy(logit_S1, labels_S1)
    second_view_loss = F.cross_entropy(logit_S2, labels_S2)
    return first_view_loss + second_view_loss

def Lcot(U_p1, U_p2):
    """Co-training loss: Jensen-Shannon divergence between two views.

    Computes ``H(0.5 * (p1 + p2)) - 0.5 * (H(p1) + H(p2))`` where ``p1`` and
    ``p2`` are the softmax distributions of the two logit batches and ``H`` is
    the entropy summed over the batch, then averages over the batch.

    Args:
        U_p1, U_p2: logits of shape ``(batch, classes)`` produced by the two
            views on the same unlabelled batch.

    Returns:
        Scalar tensor, the mean divergence per sample.
    """
    p1 = F.softmax(U_p1, dim=1)
    p2 = F.softmax(U_p2, dim=1)
    mixture = 0.5 * (p1 + p2)

    # A probability that underflows to exactly 0 would give 0 * log(0) = NaN;
    # clamping the argument of the log keeps that term at 0.
    tiny = torch.finfo(mixture.dtype).tiny
    entropy_mixture = -torch.sum(mixture * torch.log(mixture.clamp(min=tiny)))
    entropy_p1 = -torch.sum(p1 * F.log_softmax(U_p1, dim=1))
    entropy_p2 = -torch.sum(p2 * F.log_softmax(U_p2, dim=1))

    # Normalise by the actual batch size rather than a notebook-level global,
    # so the function does not depend on cell execution order.
    nb_samples = U_p1.shape[0]
    return (entropy_mixture - 0.5 * (entropy_p1 + entropy_p2)) / nb_samples


def Ldiff(logit_S1, logit_S2, perturbed_logit_S1, perturbed_logit_S2, logit_U1, logit_U2, perturbed_logit_U1, perturbed_logit_U2):
    """View-difference loss on adversarial examples.

    For the labelled (S) and the unlabelled (U) batch, sums the cross-entropy
    between one view's softmax on the clean inputs and the other view's
    log-softmax on the perturbed inputs, in both directions, and divides by
    the number of samples.

    Args:
        logit_S1, logit_S2: clean logits of view 1 / view 2 on labelled data.
        perturbed_logit_S1, perturbed_logit_S2: logits of view 1 / view 2 on
            perturbed labelled inputs.
        logit_U1, logit_U2: clean logits of view 1 / view 2 on unlabelled data.
        perturbed_logit_U1, perturbed_logit_U2: logits of view 1 / view 2 on
            perturbed unlabelled inputs.

    Returns:
        Scalar tensor, the summed cross-entropies divided by the number of
        labelled plus unlabelled samples.
    """
    def cross_term(clean_logits, perturbed_logits):
        # sum over batch and classes of softmax(clean) * log_softmax(perturbed)
        return torch.sum(
            F.softmax(clean_logits, dim=1) * F.log_softmax(perturbed_logits, dim=1)
        )

    a = cross_term(logit_S2, perturbed_logit_S1)
    b = cross_term(logit_S1, perturbed_logit_S2)
    c = cross_term(logit_U2, perturbed_logit_U1)
    d = cross_term(logit_U1, perturbed_logit_U2)

    # Labelled + unlabelled samples of this step; derived from the inputs so the
    # function does not rely on a notebook-level `batch_size` global.
    nb_samples = logit_S1.shape[0] + logit_U1.shape[0]
    return -(a + b + c + d) / nb_samples

In [None]:
def adjust_learning_rate(optimizer, epoch, base_lr=0.05, total_epochs=600):
    """Set every parameter group's learning rate from a cosine schedule.

    The rate is ``base_lr * (1 + cos(pi * epoch / total_epochs))``: it starts
    at ``2 * base_lr`` for ``epoch == 0`` and reaches 0 at
    ``epoch == total_epochs``.

    Args:
        optimizer: object exposing ``param_groups``, a list of dicts with an
            ``'lr'`` entry (e.g. a ``torch.optim`` optimizer).
        epoch: zero-based epoch index.
        base_lr: half of the initial learning rate.
        total_epochs: epoch at which the rate reaches zero.
    """
    lr = base_lr * (1.0 + np.cos(epoch * np.pi / total_epochs))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def adjust_lamda(epoch):
    """Update the global loss weights ``lambda_cot`` and ``lambda_diff``.

    During the first 80 epochs (``epoch`` is zero-based) both weights are
    ramped up as ``max * exp(-5 * (1 - t / 80) ** 2)`` with ``t = epoch + 1``;
    afterwards they are set to ``lambda_cot_max`` and ``lambda_diff_max``.
    """
    global lambda_cot, lambda_diff

    step = epoch + 1
    if step > 80:
        lambda_cot, lambda_diff = lambda_cot_max, lambda_diff_max
        return

    ramp = np.exp(-5 * (1 - step / 80) ** 2)
    lambda_cot = lambda_cot_max * ramp
    lambda_diff = lambda_diff_max * ramp

# Training

In [None]:
# Number of co-trained networks ("views"); they are trained in pairs.
nb_view = 4

if nb_view % 2 != 0:
    raise AssertionError("Nb view must be a multiple of 2")

# Architecture used for every view.
model_func = repro

# Place the models on the device selected at the top of the notebook instead of
# calling .cuda() unconditionally, which ignores the configured CPU fallback.
models = [model_func().to(device) for _ in range(nb_view)]

In [None]:
# Every training step consumes `batch_size` samples in total: U_batch_size
# unlabelled ones (46/50 of the batch) plus S_batch_size labelled ones.
batch_size = 100
U_batch_size = int(batch_size * 46./50.)
S_batch_size = batch_size - U_batch_size
# Steps per epoch. NOTE(review): train() draws exactly this many batches from
# the S and U loaders - confirm both loaders yield at least nb_batch batches if
# the split sizes change.
nb_batch = len(train_set) // batch_size

tensorboard = SummaryWriter(
    log_dir="repro_cotraining/%d_views_%s_%s" % (nb_view, get_datetime(), model_func.__name__),
    comment=model_func.__name__,
)

S_sampler = torch.utils.data.SubsetRandomSampler(S_idx)
U_sampler = torch.utils.data.SubsetRandomSampler(U_idx)

# One labelled loader per view (all sharing the same sampler object) and a
# single unlabelled loader whose batches are fed to every view.
S_loaders = [
    torch.utils.data.DataLoader(train_set, batch_size=S_batch_size, num_workers=2, sampler=S_sampler)
    for _ in range(nb_view)
]
U_loader = torch.utils.data.DataLoader(train_set, batch_size=U_batch_size, num_workers=2, sampler=U_sampler)

val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, num_workers=2, shuffle=True)

# One untargeted gradient-sign attack per view, with eps=0.02 and no clipping
# of the perturbed inputs.
adv_generators = [
    GradientSignAttack(
        models[view],
        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
        eps=0.02, clip_min=-np.inf, clip_max=np.inf, targeted=False
    )
    for view in range(nb_view)
]

In [None]:
# Final weights of the co-training and view-difference losses.
lambda_cot_max, lambda_diff_max = 10, 0.5
# Current weights; overwritten by adjust_lamda() at the start of every epoch.
lambda_cot = lambda_diff = 0.0
# Best accuracy placeholder, initialised to 0.
best_acc = 0.0

In [None]:
# A single optimizer updates the parameters of every view.
parameters = list(models[0].parameters())
for view in range(1, nb_view):
    parameters.extend(models[view].parameters())

optimizer = torch.optim.SGD(parameters, lr=0.05, momentum=0.9, weight_decay=1e-4)

## Normal train

In [None]:
# One accuracy tracker and one adversarial flip-ratio tracker per view.
accuracies_func, ratios_func = [], []
for _ in range(nb_view):
    accuracies_func.append(CategoricalAccuracy())
    ratios_func.append(Ratio())

def train(epoch):
    """Run one co-training epoch over all views.

    Reads the notebook-level ``models``, ``optimizer``, ``S_loaders``,
    ``U_loader``, ``adv_generators``, ``accuracies_func``, ``ratios_func``,
    ``tensorboard``, ``nb_view``, ``nb_batch`` and ``device``; updates the
    learning rate and the global loss weights before iterating.

    Args:
        epoch: zero-based epoch index.
    """
    # Pair up the views. With 4 or more views the pairing is drawn at random on
    # every call, i.e. once per epoch; otherwise views 0 and 1 form the pair.
    if nb_view >= 4:
        nb_tuple = nb_view // 2
        m_index = np.asarray(range(nb_view))
        np.random.shuffle(m_index)
        model_couples = np.split(m_index, nb_tuple)
    else:
        model_couples = ((0, 1),)

    for m in models:
        m.train()

    adjust_learning_rate(optimizer, epoch)
    adjust_lamda(epoch)

    # reset metrics
    for i in range(nb_view):
        accuracies_func[i].reset()
        ratios_func[i].reset()

    running_loss = 0.0
    ls = 0.0
    lc = 0.0
    ld = 0.0

    # One labelled iterator per view, one shared unlabelled iterator.
    S_iterators = [iter(s_loader) for s_loader in S_loaders]
    U_iter = iter(U_loader)

    print('')
    start_time = time.time()

    for b in range(nb_batch):
        # Use the built-in next(): DataLoader iterators are not guaranteed to
        # expose a `.next()` method.
        X_S, y_S = [], []
        for s_iterator in S_iterators:
            X, y = next(s_iterator)
            X_S.append(X.to(device))
            y_S.append(y.to(device))

        # y_U is only used for the accuracy metric, never for the losses.
        X_U, y_U = next(U_iter)
        X_U = X_U.to(device)
        y_U = y_U.to(device)

        logits_S = [models[i](X_S[i]) for i in range(nb_view)]
        logits_U = [models[i](X_U) for i in range(nb_view)]

        predictions_S, predictions_U = [], []
        for i in range(nb_view):
            _, p_s = torch.max(logits_S[i], 1)
            _, p_u = torch.max(logits_U[i], 1)

            predictions_S.append(p_s)
            predictions_U.append(p_u)

        # Switch to eval mode while crafting the adversarial examples so the
        # attack's forward passes do not update the batch-norm statistics.
        for m in models:
            m.eval()

        # Each view attacks its own labelled batch (true labels) and the shared
        # unlabelled batch (its own predictions used as labels).
        adv_S, adv_U = [], []
        for i in range(nb_view):
            adv_S.append(adv_generators[i].perturb(X_S[i], y_S[i]))
            adv_U.append(adv_generators[i].perturb(X_U, predictions_U[i]))

        for m in models:
            m.train()

        # Each view is evaluated on the adversarial examples crafted against
        # its partner in the couple drawn at the beginning of the epoch.
        adv_logits_S, adv_logits_U = [None] * nb_view, [None] * nb_view
        for couple in model_couples:
            adv_logits_S[couple[0]] = models[couple[0]](adv_S[couple[1]])
            adv_logits_S[couple[1]] = models[couple[1]](adv_S[couple[0]])

            adv_logits_U[couple[0]] = models[couple[0]](adv_U[couple[1]])
            adv_logits_U[couple[1]] = models[couple[1]](adv_U[couple[0]])

        # Clear the gradients before the backward pass of this step
        # (presumably the attack's backward pass also leaves gradients on the
        # parameters - verify against advertorch).
        optimizer.zero_grad()
        for m in models:
            m.zero_grad()

        Loss_sup = 0
        Loss_cot = 0
        Loss_diff = 0
        for couple in model_couples:
            Loss_sup += Lsup(
                logits_S[couple[0]], logits_S[couple[1]],
                y_S[couple[0]], y_S[couple[1]]
            )

            Loss_cot += Lcot(logits_U[couple[0]], logits_U[couple[1]])

            Loss_diff += Ldiff(
                logits_S[couple[0]], logits_S[couple[1]],
                adv_logits_S[couple[0]], adv_logits_S[couple[1]],
                logits_U[couple[0]], logits_U[couple[1]],
                adv_logits_U[couple[0]], adv_logits_U[couple[1]]
            )

        total_loss = Loss_sup + lambda_cot*Loss_cot + lambda_diff*Loss_diff

        total_loss.backward()
        optimizer.step()

        # Running metrics: accuracy over labelled + unlabelled samples, and the
        # percentage of unlabelled predictions flipped by the attack.
        accuracies, ratios = [], []
        for i in range(nb_view):
            # accuracy
            X_SU = torch.cat((predictions_S[i], predictions_U[i]), 0)
            y_SU = torch.cat((y_S[i], y_U), 0)
            accuracies.append(accuracies_func[i](X_SU, y_SU))

            # ratio
            _, prediction_adv_U = torch.max(adv_logits_U[i], 1)
            ratios.append(ratios_func[i](predictions_U[i], prediction_adv_U))

        running_loss += total_loss.item()
        ls += Loss_sup.item()
        lc += Loss_cot.item()
        ld += Loss_diff.item()

        # Single-line console progress, rewritten in place with "\r".
        msg = "Epoch {:4}, {:3d}% \t acc: " + "{:3.4e} " * nb_view + " - loss {:3.4e} {:3.4e} {:3.4e} {:3.4e} took: {:.2f}s"
        msg = msg.format(
            epoch+1,
            int(100 * (b+1) / nb_batch),

            *accuracies,
            running_loss/(b+1), ls/(b+1), lc/(b+1), ld/(b+1),
            time.time() - start_time,
        )
        print(msg, end="\r")

        # Log to tensorboard with a per-batch step so the points of one epoch
        # are not all written at the same x position.
        global_step = epoch * nb_batch + b
        tensorboard.add_scalar('train/total_loss', total_loss.item(), global_step)
        tensorboard.add_scalar('train/Lsup', Loss_sup.item(), global_step)
        tensorboard.add_scalar('train/Lcot', Loss_cot.item(), global_step)
        tensorboard.add_scalar('train/Ldiff', Loss_diff.item(), global_step)
        for i in range(nb_view):
            tensorboard.add_scalar("train/acc_%d" % (i+1), 100. * accuracies[i], global_step)
            tensorboard.add_scalar("train/ratio_%d" % (i+1), ratios[i], global_step)

def test(epoch):
    """Evaluate every view on the validation set and log its accuracy.

    Reads the notebook-level ``models``, ``val_loader``, ``accuracies_func``,
    ``tensorboard``, ``nb_view`` and ``device``.

    Args:
        epoch: zero-based epoch index, used as the tensorboard step.
    """
    for i in range(nb_view):
        models[i].eval()
        accuracies_func[i].reset()

    # Defined up front so logging still works if the loader yields no batch.
    accuracies = [0.0] * nb_view

    with torch.no_grad():
        for inputs, targets in val_loader:
            X_val = inputs.to(device)
            y_val = targets.to(device)

            logits = [m(X_val) for m in models]

            # Each call returns the running mean accuracy since the reset above,
            # so the values of the last batch cover the whole validation set.
            accuracies = []
            for i in range(nb_view):
                _, pred_class = torch.max(logits[i], 1)
                accuracies.append(accuracies_func[i](pred_class, y_val))

    # Logged as a percentage, like the training accuracy.
    for i in range(nb_view):
        tensorboard.add_scalar('val/acc_%d' % (i+1), 100. * accuracies[i], epoch)

    # The message is printed after the full pass, hence the fixed 100%.
    msg = "Epoch {:4}, {:3d}% \t val acc: " + "{:3.4e} " * nb_view
    msg = msg.format(
        epoch+1,
        100,

        *accuracies,
    )
    print("")
    print(msg)

# Alternate one training epoch with one validation pass, for 600 epochs.
for epoch in range(600):
    train(epoch)
    test(epoch)

# End