<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Import" data-toc-modified-id="Import-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Import</a></span></li><li><span><a href="#Utils" data-toc-modified-id="Utils-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Utils</a></span></li><li><span><a href="#Load-and-prepare-the-data" data-toc-modified-id="Load-and-prepare-the-data-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Load and prepare the data</a></span></li><li><span><a href="#Create-model" data-toc-modified-id="Create-model-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Create model</a></span></li><li><span><a href="#Loss-function-&amp;-warmup" data-toc-modified-id="Loss-function-&amp;-warmup-5"><span class="toc-item-num">5&nbsp;&nbsp;</span>Loss function &amp; warmup</a></span></li><li><span><a href="#Training" data-toc-modified-id="Training-6"><span class="toc-item-num">6&nbsp;&nbsp;</span>Training</a></span><ul class="toc-item"><li><span><a href="#Normal-train" data-toc-modified-id="Normal-train-6.1"><span class="toc-item-num">6.1&nbsp;&nbsp;</span>Normal train</a></span></li><li><span><a href="#custum-acc-train" data-toc-modified-id="custum-acc-train-6.2"><span class="toc-item-num">6.2&nbsp;&nbsp;</span>custum acc train</a></span></li><li><span><a href="#add-ratio-p(m1)-<>-p(g(m2))" data-toc-modified-id="add-ratio-p(m1)-<>-p(g(m2))-6.3"><span class="toc-item-num">6.3&nbsp;&nbsp;</span>add ratio p(m1) &lt;&gt; p(g(m2))</a></span></li><li><span><a href="#add-ratio" data-toc-modified-id="add-ratio-6.4"><span class="toc-item-num">6.4&nbsp;&nbsp;</span>add ratio</a></span></li></ul></li><li><span><a href="#End" data-toc-modified-id="End-7"><span class="toc-item-num">7&nbsp;&nbsp;</span>End</a></span></li></ul></div>

# Import

In [None]:
import os
import sys
import time
import tqdm

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms
from torch.nn.utils import weight_norm
from torch.utils.tensorboard import SummaryWriter



from advertorch.attacks import GradientSignAttack

# Utils

In [2]:
class Metrics:
    """Base class for running (per-epoch) metrics.

    Subclasses compute a per-batch ``value`` in ``__call__`` and return the
    mean of all per-batch values recorded since the last ``reset()``.
    """

    def __init__(self, epsilon=1e-10):
        self.value = 0             # value of the most recent batch
        self.accumulate_value = 0  # sum of per-batch values since the last reset
        self.count = 0             # number of batches since the last reset
        self.epsilon = epsilon     # kept for API compatibility; unused by the subclasses below

    def reset(self):
        """Clear the running statistics (call at the start of every epoch)."""
        # Also clear `value` so a stale batch value does not survive a reset.
        self.value = 0
        self.accumulate_value = 0
        self.count = 0

    def __call__(self):
        # Subclasses call this first to register one more batch.
        self.count += 1


class BinaryAccuracy(Metrics):
    """Running accuracy for probability outputs thresholded at 0.5."""

    def __init__(self, epsilon=1e-10):
        Metrics.__init__(self, epsilon)

    def __call__(self, y_pred, y_true):
        """Record one batch and return the running mean accuracy.

        Args:
            y_pred: probabilities; entries > 0.5 count as positive.
            y_true: 0/1 targets with the same shape as ``y_pred``.

        Returns:
            0-dim tensor: mean of the per-batch accuracies since the last reset.
        """
        super().__call__()

        with torch.set_grad_enabled(False):
            y_pred = (y_pred > 0.5).float()
            correct = (y_pred == y_true).float().sum()
            # numel() works for any rank; shape[0] * shape[1] raised on 1-D
            # targets and under-counted elements of 3-D+ targets.
            self.value = correct / y_true.numel()

            self.accumulate_value += self.value
            return self.accumulate_value / self.count


class CategoricalAccuracy(Metrics):
    """Running accuracy for hard class predictions (e.g. argmax indices)."""

    def __init__(self, epsilon=1e-10):
        Metrics.__init__(self, epsilon)

    def __call__(self, y_pred, y_true):
        """Record one batch and return the running mean accuracy.

        Note: this is the mean of per-batch accuracies, which equals the overall
        accuracy only when all batches have the same size.

        Returns:
            0-dim tensor: mean of the per-batch accuracies since the last reset.
        """
        super().__call__()

        with torch.set_grad_enabled(False):
            self.value = torch.mean((y_true == y_pred).float())

            self.accumulate_value += self.value

            return self.accumulate_value / self.count


class Ratio(Metrics):
    """Running percentage of positions where two prediction tensors disagree."""

    def __init__(self, epsilon=1e-10):
        Metrics.__init__(self, epsilon)

    def __call__(self, y_pred, y_adv_pred):
        """Record one batch and return the running mean disagreement (percent).

        Args:
            y_pred, y_adv_pred: class predictions of identical shape
                (tensors or sequences of ints).

        Returns:
            float: mean of the per-batch disagreement percentages since the last reset.

        Raises:
            ValueError: if the batch is empty.
        """
        y_pred = torch.as_tensor(y_pred)
        y_adv_pred = torch.as_tensor(y_adv_pred, device=y_pred.device)
        n = y_pred.numel()
        if n == 0:
            # Checked before super().__call__() so an empty batch does not
            # leave self.count incremented.
            raise ValueError("Ratio requires at least one prediction")
        super().__call__()

        # Vectorised comparison: a single device sync instead of one per element.
        n_diff = (y_pred != y_adv_pred).sum().item()
        self.value = n_diff / n * 100
        self.accumulate_value += self.value

        return self.accumulate_value / self.count

In [3]:
def reset_seed(seed=1234):
    """Seed every RNG the notebook relies on so runs are repeatable.

    Python's ``random`` is seeded too: some (older) torchvision transforms such
    as RandomHorizontalFlip draw from it rather than from torch.

    Args:
        seed: integer seed applied to random, numpy and torch (CPU and all GPUs).
    """
    import random  # local import keeps this cell self-contained

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all seeds every visible GPU, not only the current device.
    torch.cuda.manual_seed_all(seed)

reset_seed()

In [4]:
import datetime
def get_datetime(now=None):
    """Return a ``YYYY-MM-DD_HH:MM:SS`` timestamp used in log directory names.

    Args:
        now: optional ``datetime.datetime`` to format; defaults to the current
            local time.

    strftime is used because slicing ``str(now)`` breaks when the microsecond
    field is 0: ``str`` then omits the ``.ffffff`` suffix and the old
    ``[11:-7]`` slice cut the time down to a single character.
    """
    if now is None:
        now = datetime.datetime.now()
    return now.strftime("%Y-%m-%d_%H:%M:%S")

In [5]:
# CUDA for PyTorch
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Load and prepare the data

In [6]:
# Training augmentation: random translation of up to 1/16 of the image size
# (2 px on 32x32 CIFAR-10) plus random horizontal flips. No normalisation is
# applied, so ToTensor leaves pixels in [0, 1].
transform_train = transforms.Compose([
    transforms.RandomAffine(0, translate=(1/16, 1/16)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

transform_val = transforms.Compose([
    transforms.ToTensor(),
])

# The default relative path raised PermissionError when this notebook was last
# run; allow overriding it without editing code, e.g. DATASET_ROOT=/data/cifar.
dataset_root = os.environ.get("DATASET_ROOT", os.path.join("..", "dataset"))
train_set = torchvision.datasets.CIFAR10(dataset_root, train=True, download=True, transform=transform_train)
val_set = torchvision.datasets.CIFAR10(dataset_root, train=False, download=True, transform=transform_val)

PermissionError: [Errno 13] Permission denied: '../dataset'

In [None]:
# Group training-set indices by class label (CIFAR-10: 10 classes).
S_idx, U_idx = [], []
classes = [[] for _ in range(10)]

# Read labels directly from the dataset instead of calling train_set[i], which
# decoded and augmented every one of the 50k images just to obtain its label.
# (Side effect: the augmentations no longer consume torch RNG here, so later
# random streams differ from earlier runs while staying seeded/deterministic.)
labels = getattr(train_set, "targets", None)
if labels is None:  # very old torchvision without `.targets`: read every sample
    labels = [train_set[i][1] for i in tqdm.tqdm(range(len(train_set)))]

for i, label in enumerate(labels):
    classes[int(label)].append(i)

In [None]:
# Labelled / unlabelled split: N_LABELED_PER_CLASS random images of each class
# form the labelled set S, the remaining images form the unlabelled set U.
N_LABELED_PER_CLASS = 400

# Rebuild the lists here so re-running this cell does not keep appending to
# them (each re-run draws a fresh split from the numpy RNG).
S_idx, U_idx = [], []
for class_indexes in classes:
    # Shuffle a copy so `classes` itself is not modified by this cell.
    shuffled = list(class_indexes)
    np.random.shuffle(shuffled)
    S_idx += shuffled[:N_LABELED_PER_CLASS]
    U_idx += shuffled[N_LABELED_PER_CLASS:]

# Create model

In [None]:
class repro(torch.nn.Module):
    """CNN with 9 weight-normalised conv layers and a linear classifier.

    forward() reshapes its input to (N, 3, 32, 32) and returns (N, num_classes)
    logits.

    Args:
        num_classes: number of output logits (10 for CIFAR-10).
        bn_momentum: PyTorch BatchNorm momentum, i.e. the weight of the *new*
            batch statistics: running = (1 - m) * running + m * batch.
            The previous value 0.999 looks like a TF/Keras decay copied across
            conventions; in PyTorch it makes the running statistics track
            almost only the last batch. 0.001 is the PyTorch equivalent of a
            0.999 decay. (The first BN previously used PyTorch's default 0.1,
            so the layers were also inconsistent.)
    """

    def __init__(self, num_classes=10, bn_momentum=0.001):
        torch.nn.Module.__init__(self)

        def conv_bn(in_ch, out_ch, kernel_size, padding):
            # weight-normalised conv -> BatchNorm -> LeakyReLU(0.1)
            return [
                weight_norm(nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=1, padding=padding)),
                nn.BatchNorm2d(out_ch, momentum=bn_momentum),
                nn.LeakyReLU(negative_slope=0.1),
            ]

        self.features = torch.nn.Sequential(
            # 32x32 -> 16x16
            *conv_bn(3, 128, 3, 1),
            *conv_bn(128, 128, 3, 1),
            *conv_bn(128, 128, 3, 1),  # previously the only conv without BatchNorm
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
            nn.Dropout(0.5),

            # 16x16 -> 8x8
            *conv_bn(128, 256, 3, 1),
            *conv_bn(256, 256, 3, 1),
            *conv_bn(256, 256, 3, 1),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
            nn.Dropout(0.5),

            # 8x8 -> 6x6 (unpadded 3x3 conv), 1x1 convs, then 6x6 average pool -> 1x1
            *conv_bn(256, 512, 3, 0),
            *conv_bn(512, 256, 1, 0),
            *conv_bn(256, 128, 1, 0),
            nn.AvgPool2d(6, stride=2, padding=0),
        )

        self.classifier = nn.Sequential(
            weight_norm(nn.Linear(128, num_classes))
        )

    def forward(self, x):
        """Return class logits for a batch of 3x32x32 images."""
        x = x.view(-1, 3, 32, 32)
        x = self.features(x)  # (N, 128, 1, 1)
        x = x.view(-1, 128)
        return self.classifier(x)


# Loss function & warmup

In [None]:
def Lsup(logit_S1, logit_S2, labels_S1, labels_S2):
    """Supervised loss: net1's mean cross-entropy on its labelled batch plus
    net2's mean cross-entropy on its own labelled batch."""
    ce_net1 = F.cross_entropy(logit_S1, labels_S1)
    ce_net2 = F.cross_entropy(logit_S2, labels_S2)
    return ce_net1 + ce_net2

def Lcot(U_p1, U_p2):
    """Co-training loss: Jensen-Shannon divergence between the two networks'
    predictive distributions on the same unlabelled batch, averaged over it.

    JS(p1, p2) = H((p1 + p2) / 2) - (H(p1) + H(p2)) / 2,  p_i = softmax(U_p_i).

    Args:
        U_p1, U_p2: logits of net1 / net2 on one unlabelled batch; classes along dim 1.

    Returns:
        Scalar tensor: summed JS divergence divided by the batch size (U_p1.size(0)).
    """
    p1 = F.softmax(U_p1, dim=1)
    p2 = F.softmax(U_p2, dim=1)
    mix = 0.5 * (p1 + p2)

    # Mixture entropy. Clamping keeps 0 * log(0) at 0 instead of NaN when both
    # networks give an (underflowed) exactly-zero probability to some class.
    tiny = torch.finfo(mix.dtype).tiny
    h_mix = -torch.sum(mix * torch.log(mix.clamp(min=tiny)))
    # Individual entropies; log_softmax is the numerically stable form.
    h_p1 = -torch.sum(p1 * F.log_softmax(U_p1, dim=1))
    h_p2 = -torch.sum(p2 * F.log_softmax(U_p2, dim=1))

    # Normalise by the actual batch size instead of the global U_batch_size
    # (equal to it in this notebook's training loops).
    return (h_mix - 0.5 * (h_p1 + h_p2)) / U_p1.size(0)


def Ldiff(logit_S1, logit_S2, perturbed_logit_S1, perturbed_logit_S2, logit_U1, logit_U2, perturbed_logit_U1, perturbed_logit_U2):
    """View-difference loss: cross-entropies between clean predictions of one
    network and perturbed-input predictions paired with them.

    Terms (each summed over the batch):
        a = sum softmax(logit_S2) * log_softmax(perturbed_logit_S1)
        b = sum softmax(logit_S1) * log_softmax(perturbed_logit_S2)
        c = sum softmax(logit_U2) * log_softmax(perturbed_logit_U1)
        d = sum softmax(logit_U1) * log_softmax(perturbed_logit_U2)

    Returns:
        Scalar tensor -(a + b + c + d) / n, with n = rows of logit_S1 + rows of
        logit_U1 (the per-network batch size). This replaces the hidden
        dependency on the global ``batch_size``, which equals n in this notebook.
    """
    def cross_term(target_logits, pred_logits):
        # sum over batch and classes of p_target * log p_pred
        return torch.sum(F.softmax(target_logits, dim=1) * F.log_softmax(pred_logits, dim=1))

    a = cross_term(logit_S2, perturbed_logit_S1)
    b = cross_term(logit_S1, perturbed_logit_S2)
    c = cross_term(logit_U2, perturbed_logit_U1)
    d = cross_term(logit_U1, perturbed_logit_U2)

    n = logit_S1.size(0) + logit_U1.size(0)
    return -(a + b + c + d) / n

In [None]:
def adjust_learning_rate(optimizer, epoch, base_lr=0.05, total_epochs=600):
    """Cosine schedule: lr = base_lr * (1 + cos(pi * epoch / total_epochs)).

    The lr starts at 2 * base_lr (0.1 with the defaults) at epoch 0 and reaches
    0 at epoch == total_epochs. ``total_epochs`` should match the length of the
    run; the defaults reproduce the previously hard-coded 0.05 / 600.

    Args:
        optimizer: torch optimizer whose param groups are updated in place.
        epoch: 0-based epoch index.
        base_lr: half of the initial learning rate.
        total_epochs: number of epochs over which the cosine decays to 0.

    Returns:
        The learning rate that was set.
    """
    lr = base_lr * (1.0 + np.cos(epoch * np.pi / total_epochs))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


def adjust_lamda(epoch):
    """Set the global loss weights lambda_cot / lambda_diff for this epoch.

    During the first 80 (1-based) epochs both weights follow the Gaussian
    ramp-up  max * exp(-5 * (1 - epoch / 80) ** 2); after that they are held at
    lambda_cot_max / lambda_diff_max (read from module globals).
    """
    global lambda_cot
    global lambda_diff
    current = epoch + 1
    if current > 80:
        lambda_cot = lambda_cot_max
        lambda_diff = lambda_diff_max
        return
    ramp = np.exp(-5 * (1 - current / 80) ** 2)
    lambda_cot = lambda_cot_max * ramp
    lambda_diff = lambda_diff_max * ramp

# Training

In [None]:
# Two networks with the same architecture but independent random
# initialisations, one per co-training view.
model_func = repro

# Use the `device` chosen in the CUDA cell instead of hard-coding .cuda().
net1 = model_func().to(device)
net2 = model_func().to(device)

In [None]:
batch_size = 100
# 46/50 of every batch comes from the unlabelled pool (92 images with
# batch_size=100); the rest (8 images) from the labelled pool.
U_batch_size = int(batch_size * 46./50.)
S_batch_size = batch_size - U_batch_size
# Steps per epoch, measured against the full training-set size.
nb_batch = len(train_set) // batch_size

S_sampler = torch.utils.data.SubsetRandomSampler(S_idx)
U_sampler = torch.utils.data.SubsetRandomSampler(U_idx)

# S1_loader and S2_loader share one sampler; SubsetRandomSampler re-permutes on
# every iteration, so the two networks still see differently ordered batches.
S1_loader = torch.utils.data.DataLoader(train_set, batch_size=S_batch_size, sampler=S_sampler)
S2_loader = torch.utils.data.DataLoader(train_set, batch_size=S_batch_size, sampler=S_sampler)
U_loader = torch.utils.data.DataLoader(train_set, batch_size=U_batch_size, sampler=U_sampler)

# train() draws exactly nb_batch batches from each of these loaders per epoch;
# fail here instead of with StopIteration in the middle of an epoch.
assert min(len(S1_loader), len(S2_loader), len(U_loader)) >= nb_batch, (
    "labelled/unlabelled loaders yield fewer than nb_batch batches per epoch"
)

val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, num_workers=2, shuffle=True)


def _make_fgsm(net):
    """Return the FGSM attack (eps=0.02, summed CE loss, untargeted) against `net`."""
    # NOTE(review): inputs are in [0, 1] (ToTensor, no normalisation), so
    # clip_min=0., clip_max=1. may be intended here — confirm before changing.
    return GradientSignAttack(
        net,
        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
        eps=0.02, clip_min=-np.inf, clip_max=np.inf, targeted=False
    )


adv_generator_1 = _make_fgsm(net1)
adv_generator_2 = _make_fgsm(net2)

In [None]:
# Loss weights: lambda_cot / lambda_diff start at 0 and are ramped towards the
# *_max values by adjust_lamda(); best_acc holds the best mean validation accuracy.
lambda_cot_max, lambda_diff_max = 10, 0.5
lambda_cot, lambda_diff = 0.0, 0.0
best_acc = 0.0

In [None]:
# One SGD optimiser over both networks, so a single step on the combined loss
# updates net1 and net2 together. The lr given here is overwritten each epoch
# by adjust_learning_rate().
optimizer = torch.optim.SGD(
    [*net1.parameters(), *net2.parameters()],
    lr=0.05,
    momentum=0.9,
    weight_decay=1e-4,
)

## Normal train

In [None]:
def train(epoch):
    """Run one epoch of deep co-training on net1 / net2.

    Each of the nb_batch steps minimises
        Lsup + lambda_cot * Lcot + lambda_diff * Ldiff
    where Lsup uses labelled batches S1 (net1) and S2 (net2), Lcot compares both
    nets on a shared unlabelled batch U, and Ldiff compares each net's clean
    predictions with the other net's predictions on its adversarial examples.
    Progress is printed in place; per-step losses and running accuracies are
    written to the global `tensorboard` writer.

    Args:
        epoch: 0-based epoch index (drives the lr and lambda schedules).
    """
    net1.train()
    net2.train()

    adjust_learning_rate(optimizer, epoch)
    adjust_lamda(epoch)

    total_S1 = 0
    total_S2 = 0
    total_U1 = 0
    total_U2 = 0
    train_correct_S1 = 0
    train_correct_S2 = 0
    train_correct_U1 = 0
    train_correct_U2 = 0
    running_loss = 0.0
    ls = 0.0  # running sum of Lsup
    lc = 0.0  # running sum of Lcot
    ld = 0.0  # running sum of Ldiff

    # create iterators for S1, S2, U (new random order each epoch)
    S_iter1 = iter(S1_loader)
    S_iter2 = iter(S2_loader)
    U_iter = iter(U_loader)
    print('')
    start_time = time.time()
    for i in range(nb_batch):
        # Built-in next(): DataLoader iterators no longer provide a .next()
        # method in recent PyTorch releases.
        inputs_S1, labels_S1 = next(S_iter1)
        inputs_S2, labels_S2 = next(S_iter2)
        inputs_U, labels_U = next(U_iter)  # labels_U is only used for the accuracy printout, never for training

        inputs_S1, labels_S1 = inputs_S1.cuda(), labels_S1.cuda()
        inputs_S2, labels_S2 = inputs_S2.cuda(), labels_S2.cuda()
        inputs_U = inputs_U.cuda()

        logit_S1 = net1(inputs_S1)
        logit_S2 = net2(inputs_S2)
        logit_U1 = net1(inputs_U)
        logit_U2 = net2(inputs_U)

        _, predictions_S1 = torch.max(logit_S1, 1)
        _, predictions_S2 = torch.max(logit_S2, 1)

        # pseudo labels of U (used as targets for the attacks on U below)
        _, predictions_U1 = torch.max(logit_U1, 1)
        _, predictions_U2 = torch.max(logit_U2, 1)

        # Eval mode while crafting adversarial examples so the attack's forward
        # passes do not update BatchNorm running statistics (and dropout is off).
        net1.eval()
        net2.eval()
        # generate adversarial examples: generator k attacks net k
        perturbed_data_S1 = adv_generator_1.perturb(inputs_S1, labels_S1)
        perturbed_data_U1 = adv_generator_1.perturb(inputs_U, predictions_U1)

        perturbed_data_S2 = adv_generator_2.perturb(inputs_S2, labels_S2)
        perturbed_data_U2 = adv_generator_2.perturb(inputs_U, predictions_U2)
        net1.train()
        net2.train()

        # Cross-feeding: each net is evaluated on the other net's adversarial examples.
        perturbed_logit_S1 = net1(perturbed_data_S2)
        perturbed_logit_S2 = net2(perturbed_data_S1)

        perturbed_logit_U1 = net1(perturbed_data_U2)
        perturbed_logit_U2 = net2(perturbed_data_U1)

        # zero the parameter gradients (the optimizer already covers both nets;
        # the per-net calls are redundant but harmless)
        optimizer.zero_grad()
        net1.zero_grad()
        net2.zero_grad()

        Loss_sup = Lsup(logit_S1, logit_S2, labels_S1, labels_S2)
        Loss_cot = Lcot(logit_U1, logit_U2)
        Loss_diff = Ldiff(logit_S1, logit_S2, perturbed_logit_S1, perturbed_logit_S2, logit_U1, logit_U2, perturbed_logit_U1, perturbed_logit_U2)

        total_loss = Loss_sup + lambda_cot*Loss_cot + lambda_diff*Loss_diff
        total_loss.backward()
        optimizer.step()

        # Running accuracy over labelled + unlabelled samples seen this epoch
        # (predictions come from the forward pass before this step's update).
        train_correct_S1 += np.sum(predictions_S1.cpu().numpy() == labels_S1.cpu().numpy())
        total_S1 += labels_S1.size(0)

        train_correct_U1 += np.sum(predictions_U1.cpu().numpy() == labels_U.cpu().numpy())
        total_U1 += labels_U.size(0)

        train_correct_S2 += np.sum(predictions_S2.cpu().numpy() == labels_S2.cpu().numpy())
        total_S2 += labels_S2.size(0)

        train_correct_U2 += np.sum(predictions_U2.cpu().numpy() == labels_U.cpu().numpy())
        total_U2 += labels_U.size(0)

        running_loss += total_loss.item()
        ls += Loss_sup.item()
        lc += Loss_cot.item()
        ld += Loss_diff.item()

        acc1 = (train_correct_S1+train_correct_U1) / (total_S1+total_U1)
        acc2 = (train_correct_S2+train_correct_U2) / (total_S2+total_U2)
        print("Epoch {:4}, {:3d}% \t acc: {:3.4e} {:3.4e} - loss {:3.4e} {:3.4e} {:3.4e} {:3.4e} took: {:.2f}s".format(
            epoch+1,
            int(100 * (i+1) / nb_batch),

            acc1, acc2,
            running_loss/(i+1), ls/(i+1), lc/(i+1), ld/(i+1),
            time.time() - start_time,
        ), end="\r")

        # One TensorBoard point per step; logging at step=epoch stacked all
        # nb_batch values of an epoch on the same x coordinate.
        global_step = epoch * nb_batch + i
        tensorboard.add_scalar('train/total_loss', total_loss.item(), global_step)
        tensorboard.add_scalar('train/Lsup', Loss_sup.item(), global_step)
        tensorboard.add_scalar('train/Lcot', Loss_cot.item(), global_step)
        tensorboard.add_scalar('train/Ldiff', Loss_diff.item(), global_step)
        tensorboard.add_scalar("train/acc_1", 100. * acc1, global_step)
        tensorboard.add_scalar("train/acc_2",  100. * acc2, global_step)



def test(epoch):
    """Evaluate net1 and net2 on val_loader.

    Logs both accuracies (percent) to TensorBoard at step `epoch`, prints them,
    and raises the global best_acc if the mean of the two accuracies improves it.
    """
    global best_acc
    net1.eval()
    net2.eval()
    hits1 = hits2 = 0
    seen1 = seen2 = 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.cuda()
            targets = targets.cuda()

            _, pred1 = net1(inputs).max(1)
            seen1 += targets.size(0)
            hits1 += pred1.eq(targets).sum().item()

            _, pred2 = net2(inputs).max(1)
            seen2 += targets.size(0)
            hits2 += pred2.eq(targets).sum().item()

    pct1 = 100.*hits1/seen1
    pct2 = 100.*hits2/seen2
    tensorboard.add_scalar('val/acc_1', pct1, epoch)
    tensorboard.add_scalar('val/acc_2', pct2, epoch)

    print('\nnet1 test acc: %.3f%% (%d/%d) | net2 test acc: %.3f%% (%d/%d)'
        % (pct1, hits1, seen1, pct2, hits2, seen2))

    mean_acc = (pct1 + pct2)/2
    if mean_acc > best_acc:
        best_acc = mean_acc

# NOTE(review): this section logs no ratio metrics, yet its log dir is
# "simple_ratio_…", while the "add ratio" section logs to "simple_…" — the
# prefixes look swapped; confirm before comparing runs.
# `comment` has no effect when `log_dir` is given (SummaryWriter ignores it).
tensorboard = SummaryWriter(log_dir="repro_cotraining/simple_ratio_%s_%s" % (get_datetime(), model_func.__name__), comment=model_func.__name__)
try:
    # 600 epochs matches the default total_epochs of the cosine lr schedule.
    for epoch in range(0, 600):
        train(epoch)
        test(epoch)
finally:
    # Flush pending events and release the event file even if training is
    # interrupted or crashes.
    tensorboard.close()

## custum acc train

In [None]:
# Running accuracy trackers shared by train() and test() in this section
# (each function resets them before use).
acc1_func = CategoricalAccuracy()
acc2_func = CategoricalAccuracy()

def train(epoch):
    """Run one epoch of deep co-training; accuracy via CategoricalAccuracy.

    Same optimisation as the plain variant: Lsup on labelled batches S1/S2,
    lambda_cot * Lcot on a shared unlabelled batch U, and lambda_diff * Ldiff on
    adversarial examples cross-fed between the two nets. The printed/logged
    accuracy is the running mean over the concatenated S+U predictions.

    Args:
        epoch: 0-based epoch index (drives the lr and lambda schedules).
    """
    net1.train()
    net2.train()

    adjust_learning_rate(optimizer, epoch)
    adjust_lamda(epoch)

    acc1_func.reset()
    acc2_func.reset()
    running_loss = 0.0
    ls = 0.0  # running sum of Lsup
    lc = 0.0  # running sum of Lcot
    ld = 0.0  # running sum of Ldiff

    # create iterators for S1, S2, U (new random order each epoch)
    S_iter1 = iter(S1_loader)
    S_iter2 = iter(S2_loader)
    U_iter = iter(U_loader)

    print('')
    start_time = time.time()

    for i in range(nb_batch):
        # Built-in next(): DataLoader iterators no longer provide a .next()
        # method in recent PyTorch releases.
        inputs_S1, labels_S1 = next(S_iter1)
        inputs_S2, labels_S2 = next(S_iter2)
        inputs_U, labels_U = next(U_iter)  # labels_U is only used for the accuracy metric, never for training

        inputs_S1, labels_S1 = inputs_S1.cuda(), labels_S1.cuda()
        inputs_S2, labels_S2 = inputs_S2.cuda(), labels_S2.cuda()
        inputs_U, labels_U = inputs_U.cuda(), labels_U.cuda()

        logit_S1 = net1(inputs_S1)
        logit_S2 = net2(inputs_S2)
        logit_U1 = net1(inputs_U)
        logit_U2 = net2(inputs_U)

        _, predictions_S1 = torch.max(logit_S1, 1)
        _, predictions_S2 = torch.max(logit_S2, 1)

        # pseudo labels of U (used as targets for the attacks on U below)
        _, predictions_U1 = torch.max(logit_U1, 1)
        _, predictions_U2 = torch.max(logit_U2, 1)

        # Eval mode while crafting adversarial examples so the attack's forward
        # passes do not update BatchNorm running statistics (and dropout is off).
        net1.eval()
        net2.eval()

        # generate adversarial examples: generator k attacks net k
        perturbed_data_S1 = adv_generator_1.perturb(inputs_S1, labels_S1)
        perturbed_data_S2 = adv_generator_2.perturb(inputs_S2, labels_S2)

        perturbed_data_U1 = adv_generator_1.perturb(inputs_U, predictions_U1)
        perturbed_data_U2 = adv_generator_2.perturb(inputs_U, predictions_U2)

        net1.train()
        net2.train()

        # Cross-feeding: each net is evaluated on the other net's adversarial examples.
        perturbed_logit_S1 = net1(perturbed_data_S2)
        perturbed_logit_S2 = net2(perturbed_data_S1)

        perturbed_logit_U1 = net1(perturbed_data_U2)
        perturbed_logit_U2 = net2(perturbed_data_U1)

        # zero the parameter gradients (the optimizer already covers both nets;
        # the per-net calls are redundant but harmless)
        optimizer.zero_grad()
        net1.zero_grad()
        net2.zero_grad()

        Loss_sup = Lsup(logit_S1, logit_S2, labels_S1, labels_S2)
        Loss_cot = Lcot(logit_U1, logit_U2)
        Loss_diff = Ldiff(logit_S1, logit_S2, perturbed_logit_S1, perturbed_logit_S2, logit_U1, logit_U2, perturbed_logit_U1, perturbed_logit_U2)

        total_loss = Loss_sup + lambda_cot*Loss_cot + lambda_diff*Loss_diff
        total_loss.backward()
        optimizer.step()

        # calc the metrics on labelled + unlabelled predictions of each net
        pred_S1U = torch.cat((predictions_S1, predictions_U1), 0)
        pred_S2U = torch.cat((predictions_S2, predictions_U2), 0)
        y_S1U = torch.cat((labels_S1, labels_U), 0)
        y_S2U = torch.cat((labels_S2, labels_U), 0)

        acc1 = acc1_func(pred_S1U, y_S1U)
        acc2 = acc2_func(pred_S2U, y_S2U)

        running_loss += total_loss.item()
        ls += Loss_sup.item()
        lc += Loss_cot.item()
        ld += Loss_diff.item()

        print("Epoch {:4}, {:3d}% \t acc: {:3.4e} {:3.4e} - loss {:3.4e} {:3.4e} {:3.4e} {:3.4e} took: {:.2f}s".format(
            epoch+1,
            int(100 * (i+1) / nb_batch),

            acc1, acc2,
            running_loss/(i+1), ls/(i+1), lc/(i+1), ld/(i+1),
            time.time() - start_time,
        ), end="\r")

        # One TensorBoard point per step; logging at step=epoch stacked all
        # nb_batch values of an epoch on the same x coordinate.
        global_step = epoch * nb_batch + i
        tensorboard.add_scalar('train/total_loss', total_loss.item(), global_step)
        tensorboard.add_scalar('train/Lsup', Loss_sup.item(), global_step)
        tensorboard.add_scalar('train/Lcot', Loss_cot.item(), global_step)
        tensorboard.add_scalar('train/Ldiff', Loss_diff.item(), global_step)
        tensorboard.add_scalar("train/acc_1", 100. * acc1, global_step)
        tensorboard.add_scalar("train/acc_2",  100. * acc2, global_step)



def test(epoch):
    """Evaluate net1 and net2 on val_loader with the CategoricalAccuracy trackers.

    Logs both accuracies to TensorBoard at step `epoch` and prints them, in
    percent (matching the train/acc_* curves, which are logged as 100 * acc).
    The trackers average per-batch accuracies, which equals the overall
    accuracy only when all validation batches have the same size.
    """
    net1.eval()
    net2.eval()

    acc1_func.reset()
    acc2_func.reset()
    acc1 = acc2 = 0.0  # defined even if val_loader yields no batch

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.cuda()
            targets = targets.cuda()

            outputs1 = net1(inputs)
            _, predicted1 = torch.max(outputs1, 1)
            acc1 = acc1_func(predicted1, targets)

            outputs2 = net2(inputs)
            _, predicted2 = torch.max(outputs2, 1)
            acc2 = acc2_func(predicted2, targets)

    # The trackers return fractions in [0, 1]; printing them with "%" showed
    # e.g. "0.853%" for 85.3 % accuracy.
    acc1_pct = 100. * float(acc1)
    acc2_pct = 100. * float(acc2)
    tensorboard.add_scalar('val/acc_1', acc1_pct, epoch)
    tensorboard.add_scalar('val/acc_2', acc2_pct, epoch)

    print('\nnet1 test acc: %.3f%% | net2 test acc: %.3f%%'
        % (acc1_pct, acc2_pct))

# `comment` has no effect when `log_dir` is given (SummaryWriter ignores it).
tensorboard = SummaryWriter(log_dir="repro_cotraining/custum_acc_%s_%s" % (get_datetime(), model_func.__name__), comment=model_func.__name__)
try:
    # 600 epochs matches the default total_epochs of the cosine lr schedule.
    for epoch in range(0, 600):
        train(epoch)
        test(epoch)
finally:
    # Flush pending events and release the event file even if training is
    # interrupted or crashes.
    tensorboard.close()

## add ratio p(m1) <> p(g(m2))

In [None]:
# Length of the ratio experiment below.
# NOTE(review): adjust_learning_rate() is still called with its default
# 600-epoch cosine schedule, so the lr barely decays over 100 epochs.
nb_epoch = 100

# One disagreement-rate tracker per (clean prediction, perturbed prediction)
# pair compared in train() below.
ratio1_p1pg1, ratio2_p2pg2, ratio_p1pg2, ratio_p2pg1 = (Ratio() for _ in range(4))

def train(epoch):
    """Run one epoch of deep co-training and track prediction-disagreement ratios.

    Same optimisation as the plain variant (Lsup + lambda_cot * Lcot +
    lambda_diff * Ldiff). In addition, for the unlabelled batch it records the
    percentage of samples whose predicted class differs between a net's clean
    prediction and the predictions on the cross-fed adversarial batches.

    Args:
        epoch: 0-based epoch index (drives the lr and lambda schedules).
    """
    net1.train()
    net2.train()

    adjust_learning_rate(optimizer, epoch)
    adjust_lamda(epoch)

    ratio1_p1pg1.reset()
    ratio2_p2pg2.reset()
    ratio_p1pg2.reset()
    ratio_p2pg1.reset()

    total_S1 = 0
    total_S2 = 0
    total_U1 = 0
    total_U2 = 0
    train_correct_S1 = 0
    train_correct_S2 = 0
    train_correct_U1 = 0
    train_correct_U2 = 0
    running_loss = 0.0
    ls = 0.0  # running sum of Lsup
    lc = 0.0  # running sum of Lcot
    ld = 0.0  # running sum of Ldiff

    # create iterators for S1, S2, U (new random order each epoch)
    S_iter1 = iter(S1_loader)
    S_iter2 = iter(S2_loader)
    U_iter = iter(U_loader)
    print('')
    start_time = time.time()
    for i in range(nb_batch):
        # Built-in next(): DataLoader iterators no longer provide a .next()
        # method in recent PyTorch releases.
        inputs_S1, labels_S1 = next(S_iter1)
        inputs_S2, labels_S2 = next(S_iter2)
        inputs_U, labels_U = next(U_iter)  # labels_U is only used for the accuracy printout, never for training

        inputs_S1, labels_S1 = inputs_S1.cuda(), labels_S1.cuda()
        inputs_S2, labels_S2 = inputs_S2.cuda(), labels_S2.cuda()
        inputs_U = inputs_U.cuda()

        logit_S1 = net1(inputs_S1)
        logit_S2 = net2(inputs_S2)
        logit_U1 = net1(inputs_U)
        logit_U2 = net2(inputs_U)

        _, predictions_S1 = torch.max(logit_S1, 1)
        _, predictions_S2 = torch.max(logit_S2, 1)

        # pseudo labels of U (used as targets for the attacks on U below)
        _, predictions_U1 = torch.max(logit_U1, 1)
        _, predictions_U2 = torch.max(logit_U2, 1)

        # Eval mode while crafting adversarial examples so the attack's forward
        # passes do not update BatchNorm running statistics (and dropout is off).
        net1.eval()
        net2.eval()
        # generate adversarial examples: generator k attacks net k
        perturbed_data_S1 = adv_generator_1.perturb(inputs_S1, labels_S1)
        perturbed_data_U1 = adv_generator_1.perturb(inputs_U, predictions_U1)

        perturbed_data_S2 = adv_generator_2.perturb(inputs_S2, labels_S2)
        perturbed_data_U2 = adv_generator_2.perturb(inputs_U, predictions_U2)
        net1.train()
        net2.train()

        # Cross-feeding: each net is evaluated on the other net's adversarial examples.
        perturbed_logit_S1 = net1(perturbed_data_S2)
        perturbed_logit_S2 = net2(perturbed_data_S1)

        perturbed_logit_U1 = net1(perturbed_data_U2)
        perturbed_logit_U2 = net2(perturbed_data_U1)

        # zero the parameter gradients (the optimizer already covers both nets;
        # the per-net calls are redundant but harmless)
        optimizer.zero_grad()
        net1.zero_grad()
        net2.zero_grad()

        Loss_sup = Lsup(logit_S1, logit_S2, labels_S1, labels_S2)
        Loss_cot = Lcot(logit_U1, logit_U2)
        Loss_diff = Ldiff(logit_S1, logit_S2, perturbed_logit_S1, perturbed_logit_S2, logit_U1, logit_U2, perturbed_logit_U1, perturbed_logit_U2)

        total_loss = Loss_sup + lambda_cot*Loss_cot + lambda_diff*Loss_diff
        total_loss.backward()
        optimizer.step()

        # Disagreement rates (percent) on U, where
        #   prediction_perturbed_U1 = net1 on net2's adversarial U batch,
        #   prediction_perturbed_U2 = net2 on net1's adversarial U batch.
        _, prediction_perturbed_U1 = torch.max(perturbed_logit_U1, 1)
        _, prediction_perturbed_U2 = torch.max(perturbed_logit_U2, 1)
        ratio1 = ratio1_p1pg1(predictions_U1, prediction_perturbed_U1)
        ratio2 = ratio2_p2pg2(predictions_U2, prediction_perturbed_U2)
        ratio3 = ratio_p1pg2(predictions_U1, prediction_perturbed_U2)
        ratio4 = ratio_p2pg1(predictions_U2, prediction_perturbed_U1)

        # Running accuracy over labelled + unlabelled samples seen this epoch.
        train_correct_S1 += np.sum(predictions_S1.cpu().numpy() == labels_S1.cpu().numpy())
        total_S1 += labels_S1.size(0)

        train_correct_U1 += np.sum(predictions_U1.cpu().numpy() == labels_U.cpu().numpy())
        total_U1 += labels_U.size(0)

        train_correct_S2 += np.sum(predictions_S2.cpu().numpy() == labels_S2.cpu().numpy())
        total_S2 += labels_S2.size(0)

        train_correct_U2 += np.sum(predictions_U2.cpu().numpy() == labels_U.cpu().numpy())
        total_U2 += labels_U.size(0)

        running_loss += total_loss.item()
        ls += Loss_sup.item()
        lc += Loss_cot.item()
        ld += Loss_diff.item()

        acc1 = (train_correct_S1+train_correct_U1) / (total_S1+total_U1)
        acc2 = (train_correct_S2+train_correct_U2) / (total_S2+total_U2)
        print("Epoch {:4}, {:3d}% \t acc: {:3.4e} {:3.4e} - loss {:3.4e} {:3.4e} {:3.4e} {:3.4e} took: {:.2f}s".format(
            epoch+1,
            int(100 * (i+1) / nb_batch),

            acc1, acc2,
            running_loss/(i+1), ls/(i+1), lc/(i+1), ld/(i+1),
            time.time() - start_time,
        ), end="\r")

        # One TensorBoard point per step; logging at step=epoch stacked all
        # nb_batch values of an epoch on the same x coordinate.
        global_step = epoch * nb_batch + i
        tensorboard.add_scalar('train/total_loss', total_loss.item(), global_step)
        tensorboard.add_scalar('train/Lsup', Loss_sup.item(), global_step)
        tensorboard.add_scalar('train/Lcot', Loss_cot.item(), global_step)
        tensorboard.add_scalar('train/Ldiff', Loss_diff.item(), global_step)
        tensorboard.add_scalar("train/acc_1", 100. * acc1, global_step)
        tensorboard.add_scalar("train/acc_2",  100. * acc2, global_step)
        tensorboard.add_scalar("train/ratio1", ratio1, global_step)
        tensorboard.add_scalar("train/ratio2", ratio2, global_step)
        tensorboard.add_scalar("train/ratio_p1pg2", ratio3, global_step)
        tensorboard.add_scalar("train/ratio_p2pg1", ratio4, global_step)

def test(epoch):
    """Evaluate net1 and net2 on val_loader.

    Logs both accuracies (percent) to TensorBoard at step `epoch`, prints them,
    and raises the global best_acc if the mean of the two accuracies improves it.
    """
    global best_acc
    net1.eval()
    net2.eval()
    correct1 = 0
    correct2 = 0
    total1 = 0
    total2 = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(val_loader):
            inputs = inputs.cuda()
            targets = targets.cuda()

            outputs1 = net1(inputs)
            predicted1 = outputs1.max(1)  # (values, indices)
            total1 += targets.size(0)
            correct1 += predicted1[1].eq(targets).sum().item()

            outputs2 = net2(inputs)
            predicted2 = outputs2.max(1)
            total2 += targets.size(0)
            correct2 += predicted2[1].eq(targets).sum().item()

    tensorboard.add_scalar('val/acc_1', 100.*correct1/total1, epoch)
    tensorboard.add_scalar('val/acc_2', 100.*correct2/total2, epoch)

    print('\nnet1 test acc: %.3f%% (%d/%d) | net2 test acc: %.3f%% (%d/%d)'
        % (100.*correct1/total1, correct1, total1, 100.*correct2/total2, correct2, total2))

    acc = ((100.*correct1/total1)+(100.*correct2/total2))/2
    # Track the best mean accuracy like the other test() variants (the value
    # was previously computed and discarded despite `global best_acc`).
    if acc > best_acc:
        best_acc = acc

# `comment` has no effect when `log_dir` is given (SummaryWriter ignores it).
tensorboard = SummaryWriter(log_dir="repro_cotraining/ratio_p1pg2_%s_%s" % (get_datetime(), model_func.__name__), comment=model_func.__name__)
try:
    for epoch in range(0, nb_epoch):
        # Column order matches the progress line printed by train():
        # total, Lsup, Lcot, Ldiff (the old header said "sup diff cot").
        print("Epoch     ,    % \t acc: m1_acc     m2_acc     - loss total      sup        cot        diff      ")
        train(epoch)
        test(epoch)
finally:
    # Flush pending events and release the event file even if training is
    # interrupted or crashes.
    tensorboard.close()

## add ratio

In [None]:
# Disagreement-rate trackers for net1 and net2 (reset at the start of each train() epoch).
ratio1_func, ratio2_func = Ratio(), Ratio()

def train(epoch):
    """Run one epoch of deep co-training and track per-net disagreement ratios.

    Same optimisation as the plain variant (Lsup + lambda_cot * Lcot +
    lambda_diff * Ldiff). Additionally records, for each net, the percentage of
    unlabelled samples whose clean prediction differs from its prediction on
    the other net's adversarial version of the batch.

    Args:
        epoch: 0-based epoch index (drives the lr and lambda schedules).
    """
    net1.train()
    net2.train()

    adjust_learning_rate(optimizer, epoch)
    adjust_lamda(epoch)

    ratio1_func.reset()
    ratio2_func.reset()

    total_S1 = 0
    total_S2 = 0
    total_U1 = 0
    total_U2 = 0
    train_correct_S1 = 0
    train_correct_S2 = 0
    train_correct_U1 = 0
    train_correct_U2 = 0
    running_loss = 0.0
    ls = 0.0  # running sum of Lsup
    lc = 0.0  # running sum of Lcot
    ld = 0.0  # running sum of Ldiff

    # create iterators for S1, S2, U (new random order each epoch)
    S_iter1 = iter(S1_loader)
    S_iter2 = iter(S2_loader)
    U_iter = iter(U_loader)
    print('')
    start_time = time.time()
    for i in range(nb_batch):
        # Built-in next(): DataLoader iterators no longer provide a .next()
        # method in recent PyTorch releases.
        inputs_S1, labels_S1 = next(S_iter1)
        inputs_S2, labels_S2 = next(S_iter2)
        inputs_U, labels_U = next(U_iter)  # labels_U is only used for the accuracy printout, never for training

        inputs_S1, labels_S1 = inputs_S1.cuda(), labels_S1.cuda()
        inputs_S2, labels_S2 = inputs_S2.cuda(), labels_S2.cuda()
        inputs_U = inputs_U.cuda()

        logit_S1 = net1(inputs_S1)
        logit_S2 = net2(inputs_S2)
        logit_U1 = net1(inputs_U)
        logit_U2 = net2(inputs_U)

        _, predictions_S1 = torch.max(logit_S1, 1)
        _, predictions_S2 = torch.max(logit_S2, 1)

        # pseudo labels of U (used as targets for the attacks on U below)
        _, predictions_U1 = torch.max(logit_U1, 1)
        _, predictions_U2 = torch.max(logit_U2, 1)

        # Eval mode while crafting adversarial examples so the attack's forward
        # passes do not update BatchNorm running statistics (and dropout is off).
        net1.eval()
        net2.eval()
        # generate adversarial examples: generator k attacks net k
        perturbed_data_S1 = adv_generator_1.perturb(inputs_S1, labels_S1)
        perturbed_data_U1 = adv_generator_1.perturb(inputs_U, predictions_U1)

        perturbed_data_S2 = adv_generator_2.perturb(inputs_S2, labels_S2)
        perturbed_data_U2 = adv_generator_2.perturb(inputs_U, predictions_U2)
        net1.train()
        net2.train()

        # Cross-feeding: each net is evaluated on the other net's adversarial examples.
        perturbed_logit_S1 = net1(perturbed_data_S2)
        perturbed_logit_S2 = net2(perturbed_data_S1)

        perturbed_logit_U1 = net1(perturbed_data_U2)
        perturbed_logit_U2 = net2(perturbed_data_U1)

        # zero the parameter gradients (the optimizer already covers both nets;
        # the per-net calls are redundant but harmless)
        optimizer.zero_grad()
        net1.zero_grad()
        net2.zero_grad()

        Loss_sup = Lsup(logit_S1, logit_S2, labels_S1, labels_S2)
        Loss_cot = Lcot(logit_U1, logit_U2)
        Loss_diff = Ldiff(logit_S1, logit_S2, perturbed_logit_S1, perturbed_logit_S2, logit_U1, logit_U2, perturbed_logit_U1, perturbed_logit_U2)

        total_loss = Loss_sup + lambda_cot*Loss_cot + lambda_diff*Loss_diff
        total_loss.backward()
        optimizer.step()

        # Disagreement (percent) between each net's clean U prediction and its
        # prediction on the other net's adversarial U batch.
        _, prediction_perturbed_U1 = torch.max(perturbed_logit_U1, 1)
        _, prediction_perturbed_U2 = torch.max(perturbed_logit_U2, 1)
        ratio1 = ratio1_func(predictions_U1, prediction_perturbed_U1)
        ratio2 = ratio2_func(predictions_U2, prediction_perturbed_U2)

        # Running accuracy over labelled + unlabelled samples seen this epoch.
        train_correct_S1 += np.sum(predictions_S1.cpu().numpy() == labels_S1.cpu().numpy())
        total_S1 += labels_S1.size(0)

        train_correct_U1 += np.sum(predictions_U1.cpu().numpy() == labels_U.cpu().numpy())
        total_U1 += labels_U.size(0)

        train_correct_S2 += np.sum(predictions_S2.cpu().numpy() == labels_S2.cpu().numpy())
        total_S2 += labels_S2.size(0)

        train_correct_U2 += np.sum(predictions_U2.cpu().numpy() == labels_U.cpu().numpy())
        total_U2 += labels_U.size(0)

        running_loss += total_loss.item()
        ls += Loss_sup.item()
        lc += Loss_cot.item()
        ld += Loss_diff.item()

        acc1 = (train_correct_S1+train_correct_U1) / (total_S1+total_U1)
        acc2 = (train_correct_S2+train_correct_U2) / (total_S2+total_U2)
        print("Epoch {:4}, {:3d}% \t acc: {:3.4e} {:3.4e} - loss {:3.4e} {:3.4e} {:3.4e} {:3.4e} took: {:.2f}s".format(
            epoch+1,
            int(100 * (i+1) / nb_batch),

            acc1, acc2,
            running_loss/(i+1), ls/(i+1), lc/(i+1), ld/(i+1),
            time.time() - start_time,
        ), end="\r")

        # One TensorBoard point per step; logging at step=epoch stacked all
        # nb_batch values of an epoch on the same x coordinate.
        global_step = epoch * nb_batch + i
        tensorboard.add_scalar('train/total_loss', total_loss.item(), global_step)
        tensorboard.add_scalar('train/Lsup', Loss_sup.item(), global_step)
        tensorboard.add_scalar('train/Lcot', Loss_cot.item(), global_step)
        tensorboard.add_scalar('train/Ldiff', Loss_diff.item(), global_step)
        tensorboard.add_scalar("train/acc_1", 100. * acc1, global_step)
        tensorboard.add_scalar("train/acc_2",  100. * acc2, global_step)
        tensorboard.add_scalar("train/ratio1", ratio1, global_step)
        tensorboard.add_scalar("train/ratio2", ratio2, global_step)

def test(epoch):
    """Validate both networks, log/print their accuracy in percent and update
    the global best_acc with the mean of the two when it improves."""
    global best_acc
    net1.eval()
    net2.eval()
    right = [0, 0]
    count = [0, 0]
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.cuda()
            targets = targets.cuda()
            n = targets.size(0)

            idx1 = net1(inputs).max(1)[1]
            count[0] += n
            right[0] += idx1.eq(targets).sum().item()

            idx2 = net2(inputs).max(1)[1]
            count[1] += n
            right[1] += idx2.eq(targets).sum().item()

    score1 = 100.*right[0]/count[0]
    score2 = 100.*right[1]/count[1]
    tensorboard.add_scalar('val/acc_1', score1, epoch)
    tensorboard.add_scalar('val/acc_2', score2, epoch)

    print('\nnet1 test acc: %.3f%% (%d/%d) | net2 test acc: %.3f%% (%d/%d)'
        % (score1, right[0], count[0], score2, right[1], count[1]))

    avg = (score1 + score2)/2
    if avg > best_acc:
        best_acc = avg

# NOTE(review): this section logs ratio metrics under "simple_…", while the
# plain "Normal train" section logs under "simple_ratio_…" — the prefixes look
# swapped; confirm before comparing runs.
# `comment` has no effect when `log_dir` is given (SummaryWriter ignores it).
tensorboard = SummaryWriter(log_dir="repro_cotraining/simple_%s_%s" % (get_datetime(), model_func.__name__), comment=model_func.__name__)
try:
    # 600 epochs matches the default total_epochs of the cosine lr schedule.
    for epoch in range(0, 600):
        # Column order matches the progress line printed by train():
        # total, Lsup, Lcot, Ldiff (the old header said "sup diff cot").
        print("Epoch     ,    % \t acc: m1_acc     m2_acc     - loss total      sup        cot        diff      ")
        train(epoch)
        test(epoch)
finally:
    # Flush pending events and release the event file even if training is
    # interrupted or crashes.
    tensorboard.close()

# End