## Metrics Classes

In [None]:
# Mount Google Drive and work from the project folder, so that relative paths used
# below (./data for CIFAR-10, saved *.model / *.modeel state dicts, ./mlruns for
# MLflow) live on Drive and persist across Colab sessions.
from google.colab import drive
drive.mount('/content/drive')

%cd /content/drive/MyDrive/Proofpoint

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/Proofpoint


In [None]:
# Install MLflow (used below for experiment tracking). The version is unpinned;
# NOTE(review): pin a version if results must be reproducible.
!pip install mlflow --quiet

[K     |████████████████████████████████| 17.0 MB 30.9 MB/s 
[K     |████████████████████████████████| 209 kB 92.8 MB/s 
[K     |████████████████████████████████| 79 kB 3.6 MB/s 
[K     |████████████████████████████████| 182 kB 77.8 MB/s 
[K     |████████████████████████████████| 77 kB 7.0 MB/s 
[K     |████████████████████████████████| 147 kB 84.6 MB/s 
[K     |████████████████████████████████| 78 kB 7.7 MB/s 
[K     |████████████████████████████████| 62 kB 1.5 MB/s 
[K     |████████████████████████████████| 140 kB 102.3 MB/s 
[K     |████████████████████████████████| 55 kB 3.9 MB/s 
[K     |████████████████████████████████| 63 kB 1.8 MB/s 
[?25h  Building wheel for databricks-cli (setup.py) ... [?25l[?25hdone


In [None]:
import torch
from torchsummary import summary
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
import mlflow
import mlflow.sklearn
import sys
import os


In [None]:
"""
This script contains functions and classes to compute churn metrics
UNTESTED
"""

import numpy as np
import torch
from typing import Union

class ChurnMetric:
    """
    Base class shared by all churn metrics.

    Records which container type (numpy array or torch tensor) the metric
    expects and offers ``call_sanitize_inputs`` so subclasses can validate
    their arguments before computing anything.

    Args:
        tensor_type: "numpy" (inputs must be ``np.ndarray``) or
            "torch" (inputs must be ``torch.Tensor``).

    Raises:
        NotImplementedError: if ``tensor_type`` is not a supported name.
    """
    def __init__(self, tensor_type="numpy") -> None:
        supported = {"numpy": np.ndarray, "torch": torch.Tensor}
        try:
            self.tensor_type = supported[tensor_type]
        except KeyError:
            raise NotImplementedError("Unknown object type") from None

    def call_sanitize_inputs(self, **preds):
        """
        Validate named prediction / label arrays (labels are accepted too).

        Every keyword value must be an instance of ``self.tensor_type``, must
        have at most 2 dimensions, and all values must have identical shapes.

        Raises:
            TypeError: a value is not an instance of ``self.tensor_type``.
            ValueError: a value has more than 2 dims, or the shapes differ.
        """
        for name, value in preds.items():
            if not isinstance(value, self.tensor_type):
                raise TypeError(f"{name} is not an instance of {str(self.tensor_type)}")
            if len(value.shape) > 2:
                raise ValueError(f"Too many dims in {name}")
        # TODO: relax to "only the first dim must match".
        distinct_shapes = {value.shape for value in preds.values()}
        if len(distinct_shapes) > 1:
            raise ValueError("shape mismatch. shapes of preds must be same")
        

In [None]:
class Churn(ChurnMetric):
    """
    Simple churn: the number of classification disagreements between two
    predictions along axis 0.

    2-D inputs are treated as per-class scores and reduced with ``argmax(1)``
    before comparison; 1-D inputs are compared directly as labels.

    Args:
        tensor_type: "numpy" or "torch" -- the required type of both inputs.
        output_mode: "proportion" returns disagreements / number of rows,
            "count" returns the raw number of disagreements.

    Raises:
        ValueError: on an unknown ``output_mode``.
    """
    def __init__(self, tensor_type="numpy", output_mode="proportion") -> None:
        super().__init__(tensor_type=tensor_type)
        if output_mode not in {"proportion", "count"}:
            raise ValueError("Unknown output_mode")
        self.output_mode = output_mode

    def __call__(self, predA: Union[np.ndarray, torch.Tensor], predB: Union[np.ndarray, torch.Tensor]):
        """
        Compute churn between ``predA`` and ``predB``.

        Returns:
            The disagreement count (``output_mode="count"``) or the fraction of
            rows that disagree (``"proportion"``), as a numpy scalar or 0-dim
            torch tensor matching the input type.

        Raises:
            TypeError / ValueError from ``call_sanitize_inputs`` on bad inputs.
        """
        self.call_sanitize_inputs(predA=predA, predB=predB)

        if len(predA.shape) > 1:
            predA = predA.argmax(1)
            predB = predB.argmax(1)

        # Vectorised reduction. The builtin sum() iterated element by element,
        # which is very slow for tensors (one op + device sync per element on CUDA).
        churn = (predA != predB).sum()

        if self.output_mode == "proportion":
            return churn / predA.shape[0]
        return churn


In [None]:
class WinLossRatio(ChurnMetric):
    """
    Win/loss counts of a student model relative to a teacher model.

    * win:  the student is correct and the teacher is not.
    * loss: the teacher is correct and the student is not.

    Lateral churn (both wrong, with different predictions) is counted as
    neither a win nor a loss. Despite the class name, ``__call__`` returns the
    pair ``(wins, losses)``; callers form the ratio themselves.
    """
    def __init__(self, tensor_type="numpy") -> None:
        super().__init__(tensor_type)

    def __call__(self, true_labels, pred_teacher, pred_student):
        """
        Args:
            true_labels, pred_teacher, pred_student: same-shaped arrays/tensors.
                If 2-D, each is reduced with ``argmax(1)`` (so ``true_labels``
                is then expected to be one-hot).

        Returns:
            ``(wins, losses)`` as numpy integer scalars or 0-dim torch tensors.
        """
        self.call_sanitize_inputs(true_labels=true_labels, pred_teacher=pred_teacher, pred_student=pred_student)

        if len(pred_teacher.shape) > 1:
            pred_teacher = pred_teacher.argmax(1)
            pred_student = pred_student.argmax(1)
            true_labels = true_labels.argmax(1)

        teacher_correct = pred_teacher == true_labels
        student_correct = pred_student == true_labels
        # Vectorised .sum() instead of the per-element builtin sum().
        wins = (student_correct & ~teacher_correct).sum()
        losses = (teacher_correct & ~student_correct).sum()

        return wins, losses


In [None]:

class ChurnRatio(ChurnMetric):
    """
    Ratio of student-vs-teacher churn to control-vs-teacher churn.

    A value below 1 means the student disagrees with the teacher less often
    than the control model does. 2-D inputs are reduced with ``argmax(1)``.
    If the control never disagrees with the teacher, the result is the
    library's division-by-zero value (inf / nan), not an exception.
    """
    def __init__(self, tensor_type="numpy") -> None:
        super().__init__(tensor_type)

    def __call__(self, pred_teacher, pred_student, pred_control):
        # call_sanitize_inputs accepts keyword arguments only; passing these
        # positionally raised TypeError on every call.
        self.call_sanitize_inputs(pred_teacher=pred_teacher, pred_student=pred_student, pred_control=pred_control)

        # Consistent with Churn / WinLossRatio: only reduce score matrices.
        if len(pred_teacher.shape) > 1:
            pred_teacher = pred_teacher.argmax(1)
            pred_student = pred_student.argmax(1)
            pred_control = pred_control.argmax(1)

        student_churn = (pred_student != pred_teacher).sum()
        control_churn = (pred_control != pred_teacher).sum()
        return student_churn / control_churn



In [None]:
class GoodBadChurn(ChurnMetric):
    """
    Good or bad churn of a student relative to a teacher.

    * good churn: the teacher is wrong and the student is right.
    * bad churn:  the student disagrees with the teacher and is wrong. This
      includes lateral churn (both wrong, with different predictions) --
      lateral churn is bad.

    Good + bad churn together equal the total number of disagreements.

    Args:
        tensor_type: "numpy" or "torch".
        mode: "good" or "bad" (required).
        output_mode: "proportion" (divide by number of samples) or "count".

    Raises:
        ValueError: on an unknown ``mode`` or ``output_mode``.
    """
    def __init__(self, tensor_type="numpy", mode=None, output_mode="proportion") -> None:
        super().__init__(tensor_type)
        if mode not in {"good", "bad"}:
            raise ValueError("Please specify mode as good or bad")
        self.mode = mode
        if output_mode not in {"proportion", "count"}:
            raise ValueError("Unknown output_mode")
        self.output_mode = output_mode

    def __call__(self, true_labels, pred_teacher, pred_student):
        """
        Args:
            true_labels, pred_teacher, pred_student: same-shaped arrays/tensors.
                If 2-D, each is reduced with ``argmax(1)`` (so ``true_labels``
                is then expected to be one-hot); 1-D inputs are used as labels.

        Returns:
            The good/bad churn count, or that count divided by the number of samples.
        """
        self.call_sanitize_inputs(true_labels=true_labels, pred_teacher=pred_teacher, pred_student=pred_student)

        if len(pred_teacher.shape) > 1:
            pred_teacher = pred_teacher.argmax(1)
            pred_student = pred_student.argmax(1)
            true_labels = true_labels.argmax(1)

        teacher_correct = pred_teacher == true_labels
        student_correct = pred_student == true_labels
        if self.mode == "good":
            flipped = student_correct & ~teacher_correct
        else:
            # Previously this compared raw class indices (pred_student < pred_teacher),
            # which has nothing to do with correctness.
            flipped = ~student_correct & (pred_student != pred_teacher)
        churn = flipped.sum()

        if self.output_mode == "proportion":
            return churn / pred_teacher.shape[0]
        return churn

### Label Modification Methods

In [None]:
#y_true when fed outside as argument to call should be one_hot encoded
class ChurnMethod():
    """
    Placeholder base class that all label-modification methods inherit from.

    It does nothing yet, but gives one place to add behaviour shared by every method.
    """
    def __init__(self):
        pass

class Distillation(ChurnMethod):
    """
    Distillation label modifier.

    new_label = lamda * softmax(teacher(X)) + (1 - lamda) * y_true

    Args:
        teacher: model mapping a batch ``X`` to per-class logits.
        lamda: weight given to the teacher's soft labels.
    """
    def __init__(self, teacher, lamda=0.5):
        super().__init__()
        self.lamda = lamda
        self.teacher = teacher

    def __call__(self, X, y_true):
        """
        Args:
            X: input batch fed to the teacher.
            y_true: one-hot labels with the same shape as the teacher output.

        Returns:
            Soft labels with the teacher output's shape (no autograd graph).
        """
        # no_grad: the teacher is frozen; all learning happens in the student.
        with torch.no_grad():
            # A single teacher forward pass (the original ran it twice and
            # discarded the first result).
            teacher_label = nn.functional.softmax(self.teacher(X), dim=1)
            return teacher_label * self.lamda + y_true * (1 - self.lamda)

class Anchor_RCP(ChurnMethod):
    """
    Anchor / RCP label modifier.

    Rows where the teacher's argmax matches the true label get
    ``alpha * softmax(teacher(X)) + (1 - alpha) * y_true``; all other rows get
    ``epsilon * y_true``.

    Args:
        teacher: model mapping a batch ``X`` to per-class logits.
        alpha: teacher weight on rows the teacher gets right.
        epsilon: scale applied to the true label on rows the teacher gets wrong.
    """
    def __init__(self, teacher, alpha=0.5, epsilon=0.5):
        super().__init__()
        self.alpha = alpha
        self.epsilon = epsilon
        self.teacher = teacher

    def __call__(self, X, y_true):
        """
        Args:
            X: input batch fed to the teacher.
            y_true: one-hot labels with the same shape as the teacher output.

        Returns:
            Float tensor of ``y_true``'s shape on ``X``'s device.
        """
        with torch.no_grad():
            teacher_label = nn.functional.softmax(self.teacher(X), dim=1)
            correct = y_true.argmax(1) == teacher_label.argmax(1)
            # X.device works on CPU and GPU; X.get_device() returns -1 for CPU
            # tensors, which .to() rejects.
            new_label = torch.zeros(y_true.shape, device=X.device)
            new_label[~correct, :] = self.epsilon * y_true[~correct, :]
            new_label[correct, :] = teacher_label[correct, :] * self.alpha + y_true[correct, :] * (1 - self.alpha)
            return new_label
        


In [None]:
# Fixed seed for the new-train / val / old-train partition. Students and the
# teacher are saved to disk and re-evaluated later, so the split must be
# identical across kernel restarts; otherwise "validation" images may have
# been used for training by a previously saved model.
SPLIT_SEED = 0
size, padding = 32, 4
CIFAR10_MEAN, CIFAR10_STD = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)

# Training-time augmentation: random crop (with padding) + horizontal flip.
transform = transforms.Compose(
    [
        transforms.RandomCrop(size, padding), transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)
# Deterministic evaluation transform: validation must not see random crops/flips.
eval_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)

trainvalset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform
)
# Same 50k training images, served with the evaluation transform (used for valset only).
trainvalset_eval = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=eval_transform
)

split_generator = torch.Generator().manual_seed(SPLIT_SEED)
# 40k "new" training images / 10k validation images.
split1 = torch.utils.data.random_split(trainvalset, [40000, 10000], generator=split_generator)
# The 30k "old" training images are a subset of the 40k new training images.
split2 = torch.utils.data.random_split(split1[0], [30000, 10000], generator=split_generator)
newtrainset = split1[0]
# Validation indices come from split1, but images come from the augmentation-free dataset.
valset = torch.utils.data.Subset(trainvalset_eval, split1[1].indices)
oldtrainset = split2[0]

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data


In [None]:
from torch.optim import lr_scheduler

# Settings shared by every loader.
LOADER_KWARGS = dict(batch_size=128, num_workers=4, pin_memory=True)

# Training loaders are reshuffled every epoch.
old_train_loader = torch.utils.data.DataLoader(oldtrainset, shuffle=True, **LOADER_KWARGS)
new_train_loader = torch.utils.data.DataLoader(newtrainset, shuffle=True, **LOADER_KWARGS)
# Evaluation only aggregates over the whole set, so shuffling buys nothing;
# a fixed order keeps every validation pass deterministic and comparable.
val_loader = torch.utils.data.DataLoader(valset, shuffle=False, **LOADER_KWARGS)

# batchsize = 128
# SGD lr=0.1
# lr_scheduler
# random_crop, flipping, rotating



In [None]:
class newmodel(nn.Module):
    """
    Wrap a backbone and map its output to 10 class logits.

    forward computes ``final_layer(relu(resnet(x)))``, where ``final_layer``
    is ``nn.Linear(1000, 10)``. The backbone must therefore emit 1000 features.
    """
    def __init__(self, resnet) -> None:
        super().__init__()
        self.resnet = resnet
        self.final_layer = nn.Linear(1000, 10)

    def forward(self, x):
        features = self.resnet(x)
        return self.final_layer(nn.functional.relu(features))

In [None]:
class AverageMeter(object):
    """
    Track the most recent value and a running, count-weighted average.

    Attributes:
        val: last value passed to ``update``.
        sum: sum of ``val * n`` over all updates.
        count: sum of ``n`` over all updates.
        avg: ``sum / count``.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, k=1):
    """
    Top-k accuracy, in percent, of a batch of predictions.

    Args:
        output: (batch, num_classes) scores.
        target: (batch,) integer class labels.
        k: a sample counts as correct if its label is among the k highest
            scores. Defaults to 1 (plain accuracy).

    Returns:
        0-dim float tensor: 100 * (#correct) / batch.
    """
    batch_size = target.size(0)
    _, pred = output.topk(k, dim=1, largest=True, sorted=True)  # (batch, k)
    # reshape instead of view so non-contiguous targets also work.
    correct = pred.eq(target.reshape(-1, 1).expand_as(pred))
    correct_k = correct.reshape(-1).float().sum(0)
    return correct_k.mul_(100.0 / batch_size)





    #handle the learning rate scheduler.

# print(simple_churn(target_var, output.argmax(1)))


In [None]:

# Route all subsequent mlflow.log_metric / log_model calls to this experiment
# (MLflow creates it if it does not exist yet).
mlflow.set_experiment('New run 2')

2022/11/14 04:32:01 INFO mlflow.tracking.fluent: Experiment with name 'New run 2' does not exist. Creating a new experiment.


<Experiment: artifact_location='file:///content/drive/MyDrive/Proofpoint/mlruns/2', creation_time=1668400321332, experiment_id='2', last_update_time=1668400321332, lifecycle_stage='active', name='New run 2', tags={}>

In [None]:
def baseline_chunk():
    """
    Train the baseline ResNet-18 (no churn reduction) on ``new_train_loader``.

    Uses SGD (lr 0.1, momentum 0.9) for ``num_epochs`` epochs, multiplying the
    learning rate by ``DECAY`` at every epoch in ``DECAY_EPOCHS``. After each
    epoch the model is evaluated on ``val_loader``. Mean train/val loss and
    accuracy are printed and logged to MLflow with ``step=epoch``.

    Returns:
        The trained baseline model (on the GPU, left in eval mode).
    """
    DECAY_EPOCHS = [20, 120]
    DECAY = 0.1
    num_epochs = 30
    best_acc = 0
    current_learning_rate = 0.1
    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []

    baseline_model = newmodel(torch.hub.load('pytorch/vision:v0.13.1', 'resnet18', pretrained=False))
    criterion = nn.CrossEntropyLoss()
    baseline_optimizer = torch.optim.SGD(baseline_model.parameters(), lr=current_learning_rate, momentum=0.9)

    baseline_model.cuda()

    for epoch in range(num_epochs):
        if epoch in DECAY_EPOCHS and epoch != 0:
            current_learning_rate = current_learning_rate * DECAY
            for param_group in baseline_optimizer.param_groups:
                param_group['lr'] = current_learning_rate
            print("Current learning rate has decayed to %f" % current_learning_rate)
        print(epoch)

        # --- training pass (train mode: BatchNorm uses/updates batch statistics) ---
        baseline_model.train()
        per_epoch_losses = []
        per_epoch_accuracies = []
        for images, labels in new_train_loader:
            images = images.cuda()
            labels = labels.cuda()

            output = baseline_model(images)
            loss = criterion(output, labels)

            baseline_optimizer.zero_grad()
            loss.backward()
            baseline_optimizer.step()

            # .item(): keep plain floats (printable, MLflow-loggable, no GPU refs).
            per_epoch_accuracies.append(accuracy(output.detach(), labels).item())
            per_epoch_losses.append(loss.item())

        # --- validation pass ---
        # eval() + no_grad(): previously validation ran in train mode, so val
        # batches updated BatchNorm running stats and built unused autograd graphs.
        baseline_model.eval()
        per_epoch_val_losses = []
        per_epoch_val_accuracies = []
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.cuda()
                labels = labels.cuda()
                output = baseline_model(images)
                loss = criterion(output, labels)
                per_epoch_val_accuracies.append(accuracy(output, labels).item())
                per_epoch_val_losses.append(loss.item())

        train_losses.append(sum(per_epoch_losses) / len(per_epoch_losses))
        val_losses.append(sum(per_epoch_val_losses) / len(per_epoch_val_losses))
        train_accuracies.append(sum(per_epoch_accuracies) / len(per_epoch_accuracies))
        val_accuracies.append(sum(per_epoch_val_accuracies) / len(per_epoch_val_accuracies))

        # step=epoch so MLflow plots a per-epoch curve instead of stacking every value at step 0.
        mlflow.log_metric("Baseline Train accuracy", train_accuracies[-1], step=epoch)
        mlflow.log_metric("Baseline Train loss", train_losses[-1], step=epoch)
        mlflow.log_metric("Baseline Val loss", val_losses[-1], step=epoch)
        mlflow.log_metric("Baseline Val accuracy", val_accuracies[-1], step=epoch)

        print("Training loss: %.4f, Training accuracy: %.4f, Val loss: %.4f, Val accuracy: %.4f, " % (train_losses[-1], train_accuracies[-1], val_losses[-1], val_accuracies[-1]))
        if val_accuracies[-1] > best_acc:
            best_acc = val_accuracies[-1]
            print('')

    print(f"=> Best: {best_acc:.4f}")
    return baseline_model



In [None]:
def teacher_chunk():
    """
    Train the teacher ("old") ResNet-18 on ``old_train_loader``.

    Uses SGD (lr 0.1, momentum 0.9) for ``num_epochs`` epochs, multiplying the
    learning rate by ``DECAY`` at every epoch in ``DECAY_EPOCHS``. After each
    epoch the model is evaluated on ``val_loader``. Mean train/val loss and
    accuracy are printed and logged to MLflow with ``step=epoch``.

    Returns:
        The trained teacher model (on the GPU, left in eval mode).
    """
    model = newmodel(torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False))
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    model.cuda()
    DECAY_EPOCHS = [20, 120]
    DECAY = 0.1
    num_epochs = 30
    best_acc = 0
    current_learning_rate = 0.1
    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []

    for epoch in range(num_epochs):
        print("Epoch : %d" % epoch)
        if epoch in DECAY_EPOCHS and epoch != 0:
            current_learning_rate = current_learning_rate * DECAY
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_learning_rate
            print("Current learning rate has decayed to %f" % current_learning_rate)

        # --- training pass ---
        model.train()
        per_epoch_losses = []
        per_epoch_accuracies = []
        for images, labels in old_train_loader:
            images = images.cuda()
            labels = labels.cuda()

            output = model(images)
            loss = criterion(output, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item(): keep plain floats (printable, MLflow-loggable, no GPU refs).
            per_epoch_accuracies.append(accuracy(output.detach(), labels).item())
            per_epoch_losses.append(loss.item())

        # --- validation pass ---
        # eval() + no_grad(): validation must not update BatchNorm running
        # statistics or build autograd graphs.
        model.eval()
        per_epoch_val_losses = []
        per_epoch_val_accuracies = []
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.cuda()
                labels = labels.cuda()
                output = model(images)
                loss = criterion(output, labels)
                per_epoch_val_accuracies.append(accuracy(output, labels).item())
                per_epoch_val_losses.append(loss.item())

        train_losses.append(sum(per_epoch_losses) / len(per_epoch_losses))
        val_losses.append(sum(per_epoch_val_losses) / len(per_epoch_val_losses))
        train_accuracies.append(sum(per_epoch_accuracies) / len(per_epoch_accuracies))
        val_accuracies.append(sum(per_epoch_val_accuracies) / len(per_epoch_val_accuracies))

        # step=epoch so MLflow records a per-epoch curve.
        mlflow.log_metric("Teacher Train accuracy", train_accuracies[-1], step=epoch)
        mlflow.log_metric("Teacher Train loss", train_losses[-1], step=epoch)
        mlflow.log_metric("Teacher Val loss", val_losses[-1], step=epoch)
        mlflow.log_metric("Teacher Val accuracy", val_accuracies[-1], step=epoch)

        print("Training loss: %.4f, Training accuracy: %.4f, Val loss: %.4f, Val accuracy: %.4f, " % (train_losses[-1], train_accuracies[-1], val_losses[-1], val_accuracies[-1]))
        if val_accuracies[-1] > best_acc:
            best_acc = val_accuracies[-1]
            print('')

    print(f"=> Best: {best_acc:.4f}")
    return model


In [None]:
def student_chunk_rcp(model):
    """
    Train one RCP (anchor) student per alpha, with ``model`` as the frozen teacher.

    For each alpha, a fresh ResNet-18 student is trained on ``new_train_loader``
    against ``Anchor_RCP`` labels. It is evaluated on ``val_loader`` every epoch
    and logged to MLflow. Its state_dict is saved to
    f"alpha={alpha}_epsilon_{epsilon}.model" in the working directory.

    Returns:
        dict mapping (alpha, epsilon) -> trained student model.
    """
    DECAY_EPOCHS = [20, 120]
    DECAY = 0.1
    num_epochs = 30
    alphas = [0.2, 0.4, 0.6, 0.8]
    epsilon = 1.

    student_models = {}
    for alpha in alphas:
        best_acc = 0
        current_learning_rate = 0.1
        train_losses = []
        val_losses = []
        train_accuracies = []
        val_accuracies = []
        label_modifier = Anchor_RCP(model, alpha=alpha, epsilon=epsilon)
        print(f"alpha = {alpha}")
        print(f"epsilon = {epsilon}")
        student_model = newmodel(torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False))
        student_models[(alpha, epsilon)] = student_model
        criterion = nn.CrossEntropyLoss()
        student_optimizer = torch.optim.SGD(student_model.parameters(), lr=current_learning_rate, momentum=0.9)

        student_model.cuda()
        # The teacher only produces labels; keep it in eval mode throughout.
        model.eval()

        for epoch in range(num_epochs):
            print("Epoch : %d" % epoch)
            if epoch in DECAY_EPOCHS and epoch != 0:
                current_learning_rate = current_learning_rate * DECAY
                for param_group in student_optimizer.param_groups:
                    param_group['lr'] = current_learning_rate
                print("Current learning rate has decayed to %f" % current_learning_rate)

            # --- training pass ---
            student_model.train()
            per_epoch_losses = []
            per_epoch_accuracies = []
            # Students learn from the new data, like the baseline and the
            # distillation students. The teacher was trained on old_train_loader;
            # training this student on it too was a copy-paste slip.
            for images, labels in new_train_loader:
                images = images.cuda()
                labels = labels.cuda()
                labels_oh = nn.functional.one_hot(labels, num_classes=10)

                output = student_model(images)
                new_label = label_modifier(images, labels_oh)
                loss = criterion(output, new_label)
                student_optimizer.zero_grad()
                loss.backward()
                student_optimizer.step()

                per_epoch_accuracies.append(accuracy(output.detach(), labels).item())
                per_epoch_losses.append(loss.item())

            # --- validation pass (eval mode, no autograd) ---
            student_model.eval()
            per_epoch_val_losses = []
            per_epoch_val_accuracies = []
            with torch.no_grad():
                for images, labels in val_loader:
                    images = images.cuda()
                    labels = labels.cuda()
                    output = student_model(images)
                    loss = criterion(output, labels)
                    per_epoch_val_accuracies.append(accuracy(output, labels).item())
                    per_epoch_val_losses.append(loss.item())

            train_losses.append(sum(per_epoch_losses) / len(per_epoch_losses))
            val_losses.append(sum(per_epoch_val_losses) / len(per_epoch_val_losses))
            train_accuracies.append(sum(per_epoch_accuracies) / len(per_epoch_accuracies))
            val_accuracies.append(sum(per_epoch_val_accuracies) / len(per_epoch_val_accuracies))

            mlflow.log_metric(f"Student RCP Train accuracy_{alpha}", train_accuracies[-1], step=epoch)
            mlflow.log_metric(f"Student RCP  Train loss_{alpha}", train_losses[-1], step=epoch)
            mlflow.log_metric(f"Student RCP  Val loss_{alpha}", val_losses[-1], step=epoch)
            mlflow.log_metric(f"Student RCP  Val accuracy_{alpha}", val_accuracies[-1], step=epoch)

            print("Training loss: %.4f, Training accuracy: %.4f, Val loss: %.4f, Val accuracy: %.4f, " % (train_losses[-1], train_accuracies[-1], val_losses[-1], val_accuracies[-1]))
            if val_accuracies[-1] > best_acc:
                best_acc = val_accuracies[-1]
                print('')

        print(f"=> Best: {best_acc:.4f}")
        torch.save(student_model.state_dict(), f"alpha={alpha}_epsilon_{epsilon}.model")
    return student_models

In [None]:
def student_chunk_dist(model):
    """
    Train one distillation student per lamda, with ``model`` as the frozen teacher.

    For each lamda, a fresh ResNet-18 student is trained on ``new_train_loader``
    against ``Distillation`` soft labels. It is evaluated on ``val_loader`` every
    epoch and logged to MLflow. Its state_dict is saved to f"lambda{lamda}.modeel";
    that file name, typo included, is what ``metric_chunk_dist`` loads.

    Returns:
        dict mapping lamda -> trained student model.
    """
    DECAY_EPOCHS = [20, 120]
    DECAY = 0.1
    num_epochs = 30

    lamdas = [0.2, 0.4, 0.6, 0.8]
    student_models = {}
    for lamda in lamdas:
        best_acc = 0
        current_learning_rate = 0.1
        train_losses = []
        val_losses = []
        train_accuracies = []
        val_accuracies = []
        label_modifier = Distillation(model, lamda=lamda)
        print(f"lamda = {lamda}")
        student_model = newmodel(torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False))
        student_models[lamda] = student_model
        criterion = nn.CrossEntropyLoss()
        student_optimizer = torch.optim.SGD(student_model.parameters(), lr=current_learning_rate, momentum=0.9)

        student_model.cuda()
        # The teacher only produces labels; keep it in eval mode throughout.
        model.eval()

        for epoch in range(num_epochs):
            print("Epoch : %d" % epoch)
            if epoch in DECAY_EPOCHS and epoch != 0:
                current_learning_rate = current_learning_rate * DECAY
                for param_group in student_optimizer.param_groups:
                    param_group['lr'] = current_learning_rate
                print("Current learning rate has decayed to %f" % current_learning_rate)

            # --- training pass ---
            student_model.train()
            per_epoch_losses = []
            per_epoch_accuracies = []
            for images, labels in new_train_loader:
                images = images.cuda()
                labels = labels.cuda()
                # num_classes=10: without it one_hot infers the width from the
                # largest label in the batch, and a batch lacking class 9 gives
                # a shape mismatch against the teacher's 10-way output.
                labels_oh = nn.functional.one_hot(labels, num_classes=10)

                output = student_model(images)
                # (Removed an unused teacher forward pass that also built a graph.)
                new_label = label_modifier(images, labels_oh)
                loss = criterion(output, new_label)

                student_optimizer.zero_grad()
                loss.backward()
                student_optimizer.step()

                per_epoch_accuracies.append(accuracy(output.detach(), labels).item())
                per_epoch_losses.append(loss.item())

            # --- validation pass (eval mode, no autograd) ---
            student_model.eval()
            per_epoch_val_losses = []
            per_epoch_val_accuracies = []
            with torch.no_grad():
                for images, labels in val_loader:
                    images = images.cuda()
                    labels = labels.cuda()
                    output = student_model(images)
                    loss = criterion(output, labels)
                    per_epoch_val_accuracies.append(accuracy(output, labels).item())
                    per_epoch_val_losses.append(loss.item())

            train_losses.append(sum(per_epoch_losses) / len(per_epoch_losses))
            val_losses.append(sum(per_epoch_val_losses) / len(per_epoch_val_losses))
            train_accuracies.append(sum(per_epoch_accuracies) / len(per_epoch_accuracies))
            val_accuracies.append(sum(per_epoch_val_accuracies) / len(per_epoch_val_accuracies))

            mlflow.log_metric(f"Student Distill Train accuracy_{lamda}", train_accuracies[-1], step=epoch)
            mlflow.log_metric(f"Student Distill Train loss_{lamda}", train_losses[-1], step=epoch)
            mlflow.log_metric(f"Student Distill Val loss_{lamda}", val_losses[-1], step=epoch)
            mlflow.log_metric(f"Student Distill Val accuracy_{lamda}", val_accuracies[-1], step=epoch)

            print("Training loss: %.4f, Training accuracy: %.4f, Val loss: %.4f, Val accuracy: %.4f, " % (train_losses[-1], train_accuracies[-1], val_losses[-1], val_accuracies[-1]))
            if val_accuracies[-1] > best_acc:
                best_acc = val_accuracies[-1]
                print('')

        print(f"=> Best: {best_acc:.4f}")
        torch.save(student_model.state_dict(), f"lambda{lamda}.modeel")
    return student_models

In [None]:
#torch.save(model.state_dict(), f"teacher.modeel")

In [None]:
def metric_chunk_dist(model, baseline_model, run_no):
    """
    Evaluate the saved distillation students against the teacher on ``val_loader``.

    First measures teacher-vs-baseline churn, accuracy, win/loss counts and
    good/bad churn. Then, for each lamda, it loads the student saved at
    f"lambda{lamda}.modeel" and measures the same teacher-vs-student statistics.
    Both sets are printed and logged to MLflow, with metric keys prefixed by
    ``run_no``. Each student and the baseline are also logged as MLflow models.

    Args:
        model: the teacher model (on the GPU).
        baseline_model: the baseline model (on the GPU).
        run_no: prefix added to every MLflow metric key.
    """
    simple_churn = Churn(tensor_type="torch", output_mode="count")
    good_churn = GoodBadChurn(tensor_type="torch", mode="good", output_mode="count")
    bad_churn = GoodBadChurn(tensor_type="torch", mode="bad", output_mode="count")
    wlr = WinLossRatio(tensor_type="torch")
    n_val = len(valset)

    def _ratio(num, den):
        """num / den, returning inf (or nan for 0/0) instead of raising when den == 0."""
        if den == 0:
            return float("nan") if num == 0 else float("inf")
        return num / den

    def _compare(reference, candidate):
        """Sum candidate-vs-reference statistics over val_loader as Python numbers."""
        stats = {"churn": 0, "ref_acc": 0, "cand_acc": 0, "wins": 0, "losses": 0, "good": 0, "bad": 0}
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.cuda()
                labels = labels.cuda()
                labels_oh = nn.functional.one_hot(labels, num_classes=10)
                # One forward pass per model per batch; every metric reuses it
                # (previously each model was re-run for every metric).
                ref_out = reference(images)
                cand_out = candidate(images)
                stats["churn"] += simple_churn(ref_out, cand_out).item()
                stats["ref_acc"] += (ref_out.argmax(1) == labels).sum().item()
                stats["cand_acc"] += (cand_out.argmax(1) == labels).sum().item()
                wins, losses = wlr(labels_oh, ref_out, cand_out)
                stats["wins"] += wins.item()
                stats["losses"] += losses.item()
                # Accumulate over all batches: several of these previously used
                # `=`, so only the final batch was reported.
                stats["good"] += good_churn(labels_oh, ref_out, cand_out).item()
                stats["bad"] += bad_churn(labels_oh, ref_out, cand_out).item()
        return stats

    model.eval()
    baseline_model.eval()
    base = _compare(model, baseline_model)
    churn_baseline = base["churn"] / n_val
    baseline_accuracy = base["cand_acc"] / n_val
    teacher_accuracy = base["ref_acc"] / n_val
    baseline_wlr = _ratio(base["wins"], base["losses"])
    baseline_good = base["good"] / n_val
    baseline_bad = base["bad"] / n_val

    lamdas = [0.2, 0.4, 0.6, 0.8]

    for lamda in lamdas:
        student_model = newmodel(torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False))
        student_model.load_state_dict(torch.load(f"lambda{lamda}.modeel"))
        student_model.cuda()
        student_model.eval()
        mlflow.pytorch.log_model(student_model, f"student_model_Val_{lamda}")

        stud = _compare(model, student_model)
        churn_student = stud["churn"] / n_val
        student_accuracy = stud["cand_acc"] / n_val
        student_wlr = _ratio(stud["wins"], stud["losses"])
        churn_ratio = _ratio(churn_student, churn_baseline)
        student_good = stud["good"] / n_val
        student_bad = stud["bad"] / n_val

        print(f"\n\nlamda = {lamda}\n")
        print(f"student churn = {churn_student}")
        print(f"baseline churn = {churn_baseline}")
        print(f"student accuracy = {student_accuracy}")
        print(f"baseline accuracy = {baseline_accuracy}")
        print(f"teacher accuracy = {teacher_accuracy}")
        print(f"student wlr = {student_wlr}")
        print(f"baseline wlr = {baseline_wlr}")
        print(f"churn ratio = {churn_ratio}")
        print(f"student good_churn = {student_good}")
        print(f"student bad_churn = {student_bad}")
        print(f"baseline good_churn = {baseline_good}")
        print(f"baseline bad_churn = {baseline_bad}")

        mlflow.log_metric(f"{run_no} student_churn_{lamda}", churn_student)
        mlflow.log_metric(f"{run_no} baseline churn_{lamda}", churn_baseline)
        mlflow.log_metric(f"{run_no} student accuracy_{lamda}", student_accuracy)
        mlflow.log_metric(f"{run_no} baseline accuracy_{lamda}", baseline_accuracy)
        mlflow.log_metric(f"{run_no} teacher accuracy _{lamda}", teacher_accuracy)
        mlflow.log_metric(f"{run_no} student wlr_{lamda}", student_wlr)
        mlflow.log_metric(f"{run_no} baseline wlr_{lamda}", baseline_wlr)
        mlflow.log_metric(f"{run_no} churn ratio_{lamda}", churn_ratio)
        mlflow.log_metric(f"{run_no} student good_churn_{lamda}", student_good)
        mlflow.log_metric(f"{run_no} student bad_churn _{lamda}", student_bad)
        mlflow.log_metric(f"{run_no} baseline good_churn_{lamda}", baseline_good)
        mlflow.log_metric(f"{run_no} baseline bad_churn_{lamda}", baseline_bad)

        mlflow.pytorch.log_model(baseline_model, f"baseline_model_Val_{lamda}")




      

In [None]:
def metric_chunk_rcp(model, baseline_model, run_no, alphas=(0.2, 0.4, 0.6, 0.8), epsilon=1.):
    """Compare RCP-trained students and a baseline against the teacher on the validation set.

    For every ``alpha`` in ``alphas`` a student is rebuilt with
    ``newmodel(torch.hub.load(... 'resnet18' ...))`` and its weights are loaded from
    ``f"alpha={alpha}_epsilon_{epsilon}.model"``. The following are then printed and
    logged to MLflow, relative to ``model`` (the teacher), for that student and for
    ``baseline_model``:

    - churn
    - accuracy
    - win/loss ratio
    - good/bad churn

    Relies on notebook globals: ``val_loader``, ``valset``, ``newmodel``, ``Churn``,
    ``GoodBadChurn``, ``WinLossRatio``.

    Args:
        model: teacher network; churn is measured against its predictions.
        baseline_model: comparison network evaluated alongside every student.
        run_no: repetition index, used as a prefix for every MLflow metric key.
        alphas: alpha values of the student checkpoints to evaluate
            (default reproduces the original sweep).
        epsilon: epsilon value encoded in the checkpoint file names and metric keys.
    """
    simple_churn = Churn(tensor_type="torch", output_mode="count")
    good_churn = GoodBadChurn(tensor_type="torch", mode="good", output_mode="count")
    bad_churn = GoodBadChurn(tensor_type="torch", mode="bad", output_mode="count")
    wlr = WinLossRatio(tensor_type="torch")
    n_val = len(valset)

    def _per_sample(count):
        """Divide an accumulated count (a 1-element tensor, or 0 if nothing was added) by len(valset)."""
        return (count.item() if torch.is_tensor(count) else count) / n_val

    def _compare_to_teacher(candidate):
        """Sum the per-batch metrics of `candidate` vs the teacher over all of val_loader."""
        totals = dict(churn=0, acc=0, teacher_acc=0, wins=0, losses=0,
                      good_churn=0, bad_churn=0)
        # Evaluation only: skip building autograd graphs for every forward pass.
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs = inputs.cuda()
                targets = targets.cuda()
                targets_oh = nn.functional.one_hot(targets, num_classes=10)
                # Run each network once per batch instead of once per metric.
                teacher_out = model(inputs)
                candidate_out = candidate(inputs)
                totals["churn"] += simple_churn(teacher_out, candidate_out)
                totals["acc"] += torch.sum(candidate_out.argmax(1) == targets)
                totals["teacher_acc"] += torch.sum(teacher_out.argmax(1) == targets)
                win, loss = wlr(targets_oh, teacher_out, candidate_out)
                totals["wins"] += win
                totals["losses"] += loss
                # Must accumulate with +=: a plain assignment keeps only the last batch.
                totals["good_churn"] += good_churn(targets_oh, teacher_out, candidate_out)
                totals["bad_churn"] += bad_churn(targets_oh, teacher_out, candidate_out)
        return totals

    baseline_model.eval()
    # The teacher is only a reference here; eval mode makes layers such as BatchNorm
    # use (and not update) their running statistics while churn is measured.
    model.eval()

    # Baseline metrics do not depend on alpha, so compute them once.
    baseline = _compare_to_teacher(baseline_model)
    churn_baseline = _per_sample(baseline["churn"])
    baseline_acc = _per_sample(baseline["acc"])
    teacher_acc = _per_sample(baseline["teacher_acc"])
    baseline_wlr = baseline["wins"] / baseline["losses"]
    baseline_good = _per_sample(baseline["good_churn"])
    baseline_bad = _per_sample(baseline["bad_churn"])

    for alpha in alphas:
        tag = f"alpha_{alpha}_epsilon_{epsilon}"
        student_model = newmodel(torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False))
        student_model.load_state_dict(torch.load(f"alpha={alpha}_epsilon_{epsilon}.model"))
        student_model.cuda()
        mlflow.pytorch.log_model(student_model, f"Student_RCP_accuracy_alpha_{alpha}_epsilon_{epsilon}.model")
        student_model.eval()

        student = _compare_to_teacher(student_model)
        churn_student = _per_sample(student["churn"])
        student_acc = _per_sample(student["acc"])
        student_wlr = student["wins"] / student["losses"]
        student_good = _per_sample(student["good_churn"])
        student_bad = _per_sample(student["bad_churn"])
        churn_ratio = churn_student / churn_baseline

        print(f"\n\nalpha = {alpha}\n", )
        print(f"\n\nEpsilon = {epsilon}\n",)
        print(f"student churn = {churn_student}", )
        print(f"baseline churn = {churn_baseline}")
        print(f"student accuracy = {student_acc}")
        print(f"baseline accuracy = {baseline_acc}")
        print(f"teacher accuracy = {teacher_acc}")
        print(f"student wlr = {student_wlr}")
        print(f"baseline wlr = {baseline_wlr}")
        print(f"churn ratio = {churn_ratio}")
        print(f"student good_churn = {student_good}")
        print(f"student bad_churn = {student_bad}")
        print(f"baseline good_churn = {baseline_good}")
        print(f"baseline bad_churn = {baseline_bad}")

        # Keys follow one pattern: "<run_no> <metric>_alpha_<a>_epsilon_<e>".
        mlflow.log_metrics({
            f"{run_no} student_churn_{tag}": churn_student,
            f"{run_no} baseline churn_{tag}": churn_baseline,
            f"{run_no} student accuracy_{tag}": student_acc,
            f"{run_no} baseline accuracy_{tag}": baseline_acc,
            f"{run_no} teacher accuracy_{tag}": teacher_acc,
            f"{run_no} student wlr_{tag}": student_wlr,
            f"{run_no} baseline wlr_{tag}": baseline_wlr,
            f"{run_no} churn ratio_{tag}": churn_ratio,
            f"{run_no} student good_churn_{tag}": student_good,
            f"{run_no} student bad_churn_{tag}": student_bad,
            f"{run_no} baseline good_churn_{tag}": baseline_good,
            f"{run_no} baseline bad_churn_{tag}": baseline_bad,
        })

        mlflow.pytorch.log_model(baseline_model, f"RCP baseline_model_Val_{alpha}_epsilon_{epsilon}")




      

Main MLflow loop: train and evaluate the teacher, students and baseline for ten repetitions inside a single MLflow run.

In [None]:
# Run the whole pipeline ten times inside a single MLflow run. run_no is passed to
# the metric chunks, which use it as a prefix on every logged metric key.
with mlflow.start_run(run_name="Ten Runs"):
    for run_no in range(10):
        # Teacher for this repetition (teacher_chunk is defined elsewhere in the notebook).
        teacher_model = teacher_chunk()
        # Students derived from the teacher.
        # NOTE(review): rcp_models / dist_models are never used below.
        # metric_chunk_rcp reloads its students from checkpoint files instead, so
        # confirm that student_chunk_* write those files for this repetition.
        rcp_models = student_chunk_rcp(teacher_model)
        dist_models = student_chunk_dist(teacher_model)
        # Baseline; baseline_chunk receives no teacher argument.
        base_model = baseline_chunk()
        # Compare students and baseline against the teacher and log to MLflow.
        metric_chunk_rcp(teacher_model, base_model, run_no)
        metric_chunk_dist(teacher_model, base_model, run_no)

Downloading: "https://github.com/pytorch/vision/zipball/v0.10.0" to /root/.cache/torch/hub/v0.10.0.zip
  f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "


Epoch : 0
Training loss: 1.8454, Training accuracy: 31.6035, Val loss: 1.7637, Val accuracy: 35.7496, 

Epoch : 1
Training loss: 1.5903, Training accuracy: 42.0756, Val loss: 1.5129, Val accuracy: 45.1048, 

Epoch : 2
Training loss: 1.4101, Training accuracy: 49.0559, Val loss: 1.3630, Val accuracy: 51.8790, 

Epoch : 3
Training loss: 1.2824, Training accuracy: 54.6144, Val loss: 1.3046, Val accuracy: 55.7753, 

Epoch : 4
Training loss: 1.1684, Training accuracy: 58.8597, Val loss: 1.1646, Val accuracy: 58.8212, 

Epoch : 5
Training loss: 1.0897, Training accuracy: 61.5913, Val loss: 1.0963, Val accuracy: 61.0562, 

Epoch : 6
Training loss: 1.0214, Training accuracy: 64.0913, Val loss: 1.0717, Val accuracy: 63.0044, 

Epoch : 7
Training loss: 0.9648, Training accuracy: 65.9497, Val loss: 1.0343, Val accuracy: 64.3592, 

Epoch : 8
Training loss: 0.9235, Training accuracy: 67.6031, Val loss: 0.9567, Val accuracy: 66.7029, 

Epoch : 9
Training loss: 0.8678, Training accuracy: 70.0820, Val

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.8422, Training accuracy: 32.1775, Val loss: 1.6435, Val accuracy: 40.5360, 

Epoch : 1
Training loss: 1.6042, Training accuracy: 42.1509, Val loss: 1.5578, Val accuracy: 42.0589, 

Epoch : 2
Training loss: 1.4263, Training accuracy: 49.0304, Val loss: 1.3973, Val accuracy: 48.3287, 

Epoch : 3
Training loss: 1.3229, Training accuracy: 53.5040, Val loss: 1.2680, Val accuracy: 55.6665, 

Epoch : 4
Training loss: 1.1756, Training accuracy: 58.5539, Val loss: 1.1570, Val accuracy: 58.9893, 

Epoch : 5
Training loss: 1.0753, Training accuracy: 62.7050, Val loss: 1.1098, Val accuracy: 61.7188, 

Epoch : 6
Training loss: 1.0149, Training accuracy: 65.0620, Val loss: 1.0522, Val accuracy: 63.0340, 

Epoch : 7
Training loss: 0.9595, Training accuracy: 67.0900, Val loss: 0.9518, Val accuracy: 67.0985, 

Epoch : 8
Training loss: 0.9064, Training accuracy: 69.1245, Val loss: 0.9321, Val accuracy: 67.4842, 

Epoch : 9
Training loss: 0.8568, Training accuracy: 70.9242, Val loss: 0.9

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.8388, Training accuracy: 32.4468, Val loss: 1.6555, Val accuracy: 38.4988, 

Epoch : 1
Training loss: 1.5707, Training accuracy: 43.1006, Val loss: 1.4437, Val accuracy: 48.0222, 

Epoch : 2
Training loss: 1.3797, Training accuracy: 51.0417, Val loss: 1.3113, Val accuracy: 52.5910, 

Epoch : 3
Training loss: 1.2535, Training accuracy: 55.9586, Val loss: 1.2124, Val accuracy: 56.9225, 

Epoch : 4
Training loss: 1.1434, Training accuracy: 60.2338, Val loss: 1.1274, Val accuracy: 60.9177, 

Epoch : 5
Training loss: 1.0711, Training accuracy: 63.2136, Val loss: 1.0602, Val accuracy: 63.0241, 

Epoch : 6
Training loss: 1.0035, Training accuracy: 65.8754, Val loss: 1.0269, Val accuracy: 63.7757, 

Epoch : 7
Training loss: 0.9594, Training accuracy: 67.3116, Val loss: 0.9759, Val accuracy: 66.4260, 

Epoch : 8
Training loss: 0.9098, Training accuracy: 69.0171, Val loss: 0.9788, Val accuracy: 66.1788, 
Epoch : 9
Training loss: 0.8732, Training accuracy: 70.6261, Val loss: 0.86

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.8243, Training accuracy: 33.4375, Val loss: 1.6276, Val accuracy: 39.8141, 

Epoch : 1
Training loss: 1.5620, Training accuracy: 43.7422, Val loss: 1.4502, Val accuracy: 47.1717, 

Epoch : 2
Training loss: 1.3783, Training accuracy: 51.4040, Val loss: 1.2992, Val accuracy: 53.2931, 

Epoch : 3
Training loss: 1.2473, Training accuracy: 56.8296, Val loss: 1.1431, Val accuracy: 60.1958, 

Epoch : 4
Training loss: 1.1423, Training accuracy: 60.9375, Val loss: 1.1314, Val accuracy: 60.2848, 

Epoch : 5
Training loss: 1.0637, Training accuracy: 63.9428, Val loss: 1.0852, Val accuracy: 61.4419, 

Epoch : 6
Training loss: 1.0082, Training accuracy: 65.8400, Val loss: 1.0082, Val accuracy: 65.1009, 

Epoch : 7
Training loss: 0.9444, Training accuracy: 68.4386, Val loss: 0.9618, Val accuracy: 66.6337, 

Epoch : 8
Training loss: 0.9083, Training accuracy: 69.9457, Val loss: 0.9049, Val accuracy: 68.4731, 

Epoch : 9
Training loss: 0.8549, Training accuracy: 71.7376, Val loss: 0.8

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.8335, Training accuracy: 32.6053, Val loss: 1.7399, Val accuracy: 35.2354, 

Epoch : 1
Training loss: 1.5698, Training accuracy: 43.7633, Val loss: 1.4827, Val accuracy: 46.2520, 

Epoch : 2
Training loss: 1.4013, Training accuracy: 50.7425, Val loss: 1.3354, Val accuracy: 50.6527, 

Epoch : 3
Training loss: 1.2648, Training accuracy: 56.0007, Val loss: 1.2091, Val accuracy: 57.1203, 

Epoch : 4
Training loss: 1.1562, Training accuracy: 60.4776, Val loss: 1.1183, Val accuracy: 59.8991, 

Epoch : 5
Training loss: 1.1155, Training accuracy: 62.1077, Val loss: 1.0590, Val accuracy: 62.5989, 

Epoch : 6
Training loss: 1.0304, Training accuracy: 65.3568, Val loss: 1.0301, Val accuracy: 64.3888, 

Epoch : 7
Training loss: 0.9790, Training accuracy: 67.5964, Val loss: 0.9597, Val accuracy: 66.0601, 

Epoch : 8
Training loss: 0.9502, Training accuracy: 68.7744, Val loss: 0.9453, Val accuracy: 67.1578, 

Epoch : 9
Training loss: 0.9118, Training accuracy: 69.9856, Val loss: 0.9

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.7522, Training accuracy: 35.7803, Val loss: 1.5831, Val accuracy: 41.6040, 

Epoch : 1
Training loss: 1.4399, Training accuracy: 48.2478, Val loss: 1.4275, Val accuracy: 49.1396, 

Epoch : 2
Training loss: 1.2355, Training accuracy: 56.0828, Val loss: 1.2371, Val accuracy: 56.3983, 

Epoch : 3
Training loss: 1.1055, Training accuracy: 61.0448, Val loss: 1.0881, Val accuracy: 61.7385, 

Epoch : 4
Training loss: 1.0020, Training accuracy: 64.8537, Val loss: 1.0493, Val accuracy: 63.3505, 

Epoch : 5
Training loss: 0.9327, Training accuracy: 67.4795, Val loss: 0.9637, Val accuracy: 67.0688, 

Epoch : 6
Training loss: 0.8653, Training accuracy: 70.0679, Val loss: 0.9403, Val accuracy: 67.2073, 

Epoch : 7
Training loss: 0.8279, Training accuracy: 71.4457, Val loss: 0.8791, Val accuracy: 69.3631, 

Epoch : 8
Training loss: 0.7759, Training accuracy: 73.3701, Val loss: 0.8352, Val accuracy: 71.6278, 

Epoch : 9
Training loss: 0.7471, Training accuracy: 74.3161, Val loss: 0.8

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.7411, Training accuracy: 35.9150, Val loss: 1.5591, Val accuracy: 43.0380, 

Epoch : 1
Training loss: 1.4138, Training accuracy: 48.8768, Val loss: 1.3249, Val accuracy: 52.7492, 

Epoch : 2
Training loss: 1.1946, Training accuracy: 57.2359, Val loss: 1.2365, Val accuracy: 56.8434, 

Epoch : 3
Training loss: 1.0745, Training accuracy: 62.2329, Val loss: 1.0965, Val accuracy: 60.8683, 

Epoch : 4
Training loss: 0.9772, Training accuracy: 65.7872, Val loss: 1.0078, Val accuracy: 64.6855, 

Epoch : 5
Training loss: 0.9036, Training accuracy: 68.2907, Val loss: 0.9376, Val accuracy: 67.7907, 

Epoch : 6
Training loss: 0.8441, Training accuracy: 70.6170, Val loss: 0.8809, Val accuracy: 69.5609, 

Epoch : 7
Training loss: 0.7961, Training accuracy: 72.5539, Val loss: 0.8638, Val accuracy: 70.4608, 

Epoch : 8
Training loss: 0.7570, Training accuracy: 74.2637, Val loss: 0.8263, Val accuracy: 71.9244, 

Epoch : 9
Training loss: 0.7187, Training accuracy: 75.4892, Val loss: 0.7

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.7179, Training accuracy: 36.6389, Val loss: 1.5514, Val accuracy: 43.3643, 

Epoch : 1
Training loss: 1.3829, Training accuracy: 50.0649, Val loss: 1.2896, Val accuracy: 53.1349, 

Epoch : 2
Training loss: 1.1603, Training accuracy: 58.9906, Val loss: 1.1535, Val accuracy: 59.7211, 

Epoch : 3
Training loss: 1.0136, Training accuracy: 64.1648, Val loss: 1.0258, Val accuracy: 63.9241, 

Epoch : 4
Training loss: 0.9160, Training accuracy: 68.1659, Val loss: 0.9438, Val accuracy: 67.0490, 

Epoch : 5
Training loss: 0.8408, Training accuracy: 70.7044, Val loss: 0.8887, Val accuracy: 68.8489, 

Epoch : 6
Training loss: 0.7859, Training accuracy: 72.8410, Val loss: 0.8394, Val accuracy: 71.4794, 

Epoch : 7
Training loss: 0.7404, Training accuracy: 74.6331, Val loss: 0.7901, Val accuracy: 72.8343, 

Epoch : 8
Training loss: 0.7115, Training accuracy: 75.7738, Val loss: 0.7765, Val accuracy: 73.4771, 

Epoch : 9
Training loss: 0.6786, Training accuracy: 77.0892, Val loss: 0.7

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.7179, Training accuracy: 36.5141, Val loss: 1.5178, Val accuracy: 44.6499, 

Epoch : 1
Training loss: 1.3642, Training accuracy: 50.5666, Val loss: 1.3525, Val accuracy: 51.2362, 

Epoch : 2
Training loss: 1.1547, Training accuracy: 59.0480, Val loss: 1.1070, Val accuracy: 61.2144, 

Epoch : 3
Training loss: 1.0018, Training accuracy: 64.4843, Val loss: 0.9893, Val accuracy: 65.0119, 

Epoch : 4
Training loss: 0.9093, Training accuracy: 67.6992, Val loss: 0.9334, Val accuracy: 67.2765, 

Epoch : 5
Training loss: 0.8310, Training accuracy: 70.7393, Val loss: 0.8692, Val accuracy: 69.6203, 

Epoch : 6
Training loss: 0.7785, Training accuracy: 72.5689, Val loss: 0.8651, Val accuracy: 70.5498, 

Epoch : 7
Training loss: 0.7376, Training accuracy: 73.8593, Val loss: 0.8167, Val accuracy: 71.9343, 

Epoch : 8
Training loss: 0.7074, Training accuracy: 75.4593, Val loss: 0.7766, Val accuracy: 73.3386, 

Epoch : 9
Training loss: 0.6734, Training accuracy: 76.7622, Val loss: 0.7

Downloading: "https://github.com/pytorch/vision/zipball/v0.13.1" to /root/.cache/torch/hub/v0.13.1.zip


0
Training loss: 1.8048, Training accuracy: 33.4590, Val loss: 1.5977, Val accuracy: 42.7809, 

1
Training loss: 1.4601, Training accuracy: 46.8426, Val loss: 1.3801, Val accuracy: 49.8912, 

2
Training loss: 1.2785, Training accuracy: 54.3705, Val loss: 1.2178, Val accuracy: 56.9620, 

3
Training loss: 1.1362, Training accuracy: 59.7769, Val loss: 1.1226, Val accuracy: 61.4023, 

4
Training loss: 1.0419, Training accuracy: 63.3786, Val loss: 1.0776, Val accuracy: 62.6681, 

5
Training loss: 0.9790, Training accuracy: 65.7947, Val loss: 0.9867, Val accuracy: 64.9130, 

6
Training loss: 0.9174, Training accuracy: 67.8215, Val loss: 0.9446, Val accuracy: 67.3556, 

7
Training loss: 0.8590, Training accuracy: 70.1003, Val loss: 0.8937, Val accuracy: 69.2939, 

8
Training loss: 0.8180, Training accuracy: 71.8251, Val loss: 0.8826, Val accuracy: 69.5609, 

9
Training loss: 0.7896, Training accuracy: 72.7137, Val loss: 0.8669, Val accuracy: 69.5411, 
10
Training loss: 0.7692, Training accura

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0




alpha = 0.2



Epsilon = 1.0

student churn = 0.1627
baseline churn = 0.1716
student accuracy = 0.792
baseline accuracy = 0.8076
teacher accuracy = 0.7863
student wlr = 1.031301498413086
baseline wlr = 1.390109896659851
churn ratio = 0.9481351981351982
student good_churn = 0.0626
student bad_churn = 0.0001
baseline good_churn = 0.0001
baseline bad_churn = 0.0003


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0
  f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "




alpha = 0.4



Epsilon = 1.0

student churn = 0.167
baseline churn = 0.1716
student accuracy = 0.7889
baseline accuracy = 0.8076
teacher accuracy = 0.7863
student wlr = 1.0197869539260864
baseline wlr = 1.390109896659851
churn ratio = 0.9731934731934733
student good_churn = 0.067
student bad_churn = 0.0002
baseline good_churn = 0.0001
baseline bad_churn = 0.0003


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0




alpha = 0.6



Epsilon = 1.0

student churn = 0.1646
baseline churn = 0.1716
student accuracy = 0.7885
baseline accuracy = 0.8076
teacher accuracy = 0.7863
student wlr = 1.0739495754241943
baseline wlr = 1.390109896659851
churn ratio = 0.9592074592074592
student good_churn = 0.0639
student bad_churn = 0.0002
baseline good_churn = 0.0001
baseline bad_churn = 0.0003


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0




alpha = 0.8



Epsilon = 1.0

student churn = 0.1665
baseline churn = 0.1716
student accuracy = 0.7809
baseline accuracy = 0.8076
teacher accuracy = 0.7863
student wlr = 0.9567233324050903
baseline wlr = 1.390109896659851
churn ratio = 0.9702797202797203
student good_churn = 0.0619
student bad_churn = 0.0
baseline good_churn = 0.0001
baseline bad_churn = 0.0003


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0




lamda = 0.2

student churn = 0.164
baseline churn = 0.1625
student accuracy = 0.804
baseline accuracy = 0.808
teacher accuracy = 0.791
student wlr = 1.2620320320129395
baseline wlr = 1.3142329454421997
churn ratio = 1.0092307692307692
student good_churn = 0.0708
student bad_churn = 0.0003
baseline good_churn = 0.0001
baseline bad_churn = 0.0001


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0




lamda = 0.4

student churn = 0.1456
baseline churn = 0.1625
student accuracy = 0.8097
baseline accuracy = 0.808
teacher accuracy = 0.791
student wlr = 1.5011235475540161
baseline wlr = 1.3142329454421997
churn ratio = 0.896
student good_churn = 0.0668
student bad_churn = 0.0001
baseline good_churn = 0.0001
baseline bad_churn = 0.0001


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0




lamda = 0.6

student churn = 0.1373
baseline churn = 0.1625
student accuracy = 0.8076
baseline accuracy = 0.808
teacher accuracy = 0.791
student wlr = 1.4562647342681885
baseline wlr = 1.3142329454421997
churn ratio = 0.8449230769230769
student good_churn = 0.0616
student bad_churn = 0.0
baseline good_churn = 0.0001
baseline bad_churn = 0.0001


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0




lamda = 0.8

student churn = 0.1278
baseline churn = 0.1625
student accuracy = 0.8005
baseline accuracy = 0.808
teacher accuracy = 0.791
student wlr = 1.3679012060165405
baseline wlr = 1.3142329454421997
churn ratio = 0.7864615384615384
student good_churn = 0.0554
student bad_churn = 0.0001
baseline good_churn = 0.0001
baseline bad_churn = 0.0001


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Epoch : 0
Training loss: 1.8209, Training accuracy: 32.6507, Val loss: 1.6359, Val accuracy: 39.9328, 

Epoch : 1
Training loss: 1.5537, Training accuracy: 43.5561, Val loss: 1.4786, Val accuracy: 45.1543, 

Epoch : 2
Training loss: 1.3651, Training accuracy: 51.2134, Val loss: 1.3163, Val accuracy: 52.4525, 

Epoch : 3
Training loss: 1.2396, Training accuracy: 56.1835, Val loss: 1.2298, Val accuracy: 55.3006, 

Epoch : 4
Training loss: 1.1449, Training accuracy: 59.4836, Val loss: 1.1626, Val accuracy: 59.3651, 

Epoch : 5
Training loss: 1.0535, Training accuracy: 62.8923, Val loss: 1.0876, Val accuracy: 62.5791, 

Epoch : 6
Training loss: 1.0006, Training accuracy: 64.9701, Val loss: 1.0289, Val accuracy: 64.7152, 

Epoch : 7
Training loss: 0.9412, Training accuracy: 67.0047, Val loss: 1.0363, Val accuracy: 63.8054, 
Epoch : 8
Training loss: 0.9257, Training accuracy: 67.5831, Val loss: 0.9336, Val accuracy: 67.2468, 

Epoch : 9
Training loss: 0.8621, Training accuracy: 70.1740, Val 

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.8484, Training accuracy: 31.9814, Val loss: 1.6771, Val accuracy: 39.0328, 

Epoch : 1
Training loss: 1.5758, Training accuracy: 42.6152, Val loss: 1.4885, Val accuracy: 44.4422, 

Epoch : 2
Training loss: 1.4094, Training accuracy: 49.2775, Val loss: 1.3433, Val accuracy: 51.9086, 

Epoch : 3
Training loss: 1.2782, Training accuracy: 55.1186, Val loss: 1.2682, Val accuracy: 54.1733, 

Epoch : 4
Training loss: 1.1745, Training accuracy: 58.8231, Val loss: 1.1594, Val accuracy: 59.2563, 

Epoch : 5
Training loss: 1.0985, Training accuracy: 61.4993, Val loss: 1.1189, Val accuracy: 60.3936, 

Epoch : 6
Training loss: 1.0282, Training accuracy: 64.1201, Val loss: 1.0447, Val accuracy: 63.4494, 

Epoch : 7
Training loss: 0.9705, Training accuracy: 66.5171, Val loss: 1.0347, Val accuracy: 63.5384, 

Epoch : 8
Training loss: 0.9341, Training accuracy: 67.8557, Val loss: 0.9381, Val accuracy: 67.7314, 

Epoch : 9
Training loss: 0.8853, Training accuracy: 69.4625, Val loss: 0.9

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.8295, Training accuracy: 32.1875, Val loss: 1.6824, Val accuracy: 38.1725, 

Epoch : 1
Training loss: 1.5746, Training accuracy: 43.0419, Val loss: 1.5291, Val accuracy: 44.0467, 

Epoch : 2
Training loss: 1.3756, Training accuracy: 50.9264, Val loss: 1.3022, Val accuracy: 54.1139, 

Epoch : 3
Training loss: 1.2500, Training accuracy: 56.2289, Val loss: 1.2207, Val accuracy: 56.5961, 

Epoch : 4
Training loss: 1.1525, Training accuracy: 60.0355, Val loss: 1.1567, Val accuracy: 59.5036, 

Epoch : 5
Training loss: 1.0783, Training accuracy: 63.1549, Val loss: 1.0440, Val accuracy: 63.4197, 

Epoch : 6
Training loss: 1.0158, Training accuracy: 65.4555, Val loss: 1.0202, Val accuracy: 64.2010, 

Epoch : 7
Training loss: 0.9707, Training accuracy: 67.0501, Val loss: 1.0218, Val accuracy: 65.3976, 

Epoch : 8
Training loss: 0.9172, Training accuracy: 68.9539, Val loss: 0.9391, Val accuracy: 67.5336, 

Epoch : 9
Training loss: 0.8819, Training accuracy: 70.4344, Val loss: 0.9

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.8190, Training accuracy: 33.2879, Val loss: 1.6577, Val accuracy: 37.6582, 

Epoch : 1
Training loss: 1.5465, Training accuracy: 45.0089, Val loss: 1.4322, Val accuracy: 47.6661, 

Epoch : 2
Training loss: 1.3605, Training accuracy: 51.6090, Val loss: 1.3291, Val accuracy: 52.3438, 

Epoch : 3
Training loss: 1.2532, Training accuracy: 56.2434, Val loss: 1.2001, Val accuracy: 57.2686, 

Epoch : 4
Training loss: 1.1481, Training accuracy: 60.6316, Val loss: 1.1488, Val accuracy: 60.2551, 

Epoch : 5
Training loss: 1.0775, Training accuracy: 63.1516, Val loss: 1.0786, Val accuracy: 62.3418, 

Epoch : 6
Training loss: 1.0174, Training accuracy: 65.7114, Val loss: 1.0295, Val accuracy: 63.5779, 

Epoch : 7
Training loss: 0.9657, Training accuracy: 67.7704, Val loss: 0.9475, Val accuracy: 66.9205, 

Epoch : 8
Training loss: 0.9239, Training accuracy: 69.6221, Val loss: 0.9326, Val accuracy: 67.6721, 

Epoch : 9
Training loss: 0.8804, Training accuracy: 71.3974, Val loss: 0.8

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.8335, Training accuracy: 32.6640, Val loss: 1.6457, Val accuracy: 39.7547, 

Epoch : 1
Training loss: 1.5736, Training accuracy: 43.2469, Val loss: 1.4991, Val accuracy: 46.0443, 

Epoch : 2
Training loss: 1.4072, Training accuracy: 50.1285, Val loss: 1.3477, Val accuracy: 51.8592, 

Epoch : 3
Training loss: 1.2657, Training accuracy: 55.8644, Val loss: 1.1697, Val accuracy: 58.0202, 

Epoch : 4
Training loss: 1.1643, Training accuracy: 60.4222, Val loss: 1.1224, Val accuracy: 59.7805, 

Epoch : 5
Training loss: 1.0772, Training accuracy: 63.8132, Val loss: 1.0865, Val accuracy: 61.2638, 

Epoch : 6
Training loss: 1.0159, Training accuracy: 66.2090, Val loss: 1.0111, Val accuracy: 63.4296, 

Epoch : 7
Training loss: 0.9715, Training accuracy: 68.3810, Val loss: 0.9305, Val accuracy: 67.4644, 

Epoch : 8
Training loss: 0.9217, Training accuracy: 70.1252, Val loss: 0.9166, Val accuracy: 67.5831, 

Epoch : 9
Training loss: 0.8886, Training accuracy: 71.1813, Val loss: 0.8

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.7838, Training accuracy: 34.1254, Val loss: 1.5785, Val accuracy: 41.3271, 

Epoch : 1
Training loss: 1.4547, Training accuracy: 47.3018, Val loss: 1.3453, Val accuracy: 50.8109, 

Epoch : 2
Training loss: 1.2399, Training accuracy: 56.0678, Val loss: 1.2517, Val accuracy: 56.0819, 

Epoch : 3
Training loss: 1.1008, Training accuracy: 61.5665, Val loss: 1.0699, Val accuracy: 62.8461, 

Epoch : 4
Training loss: 0.9949, Training accuracy: 65.4228, Val loss: 1.0039, Val accuracy: 65.3382, 

Epoch : 5
Training loss: 0.9197, Training accuracy: 68.0486, Val loss: 0.9542, Val accuracy: 66.9798, 

Epoch : 6
Training loss: 0.8657, Training accuracy: 70.2376, Val loss: 0.9150, Val accuracy: 68.4434, 

Epoch : 7
Training loss: 0.8312, Training accuracy: 71.3334, Val loss: 0.8637, Val accuracy: 70.1246, 

Epoch : 8
Training loss: 0.7915, Training accuracy: 72.8235, Val loss: 0.8351, Val accuracy: 71.3014, 

Epoch : 9
Training loss: 0.7435, Training accuracy: 74.6506, Val loss: 0.8

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.7224, Training accuracy: 36.3968, Val loss: 1.5239, Val accuracy: 45.1938, 

Epoch : 1
Training loss: 1.3730, Training accuracy: 50.7713, Val loss: 1.3459, Val accuracy: 52.4822, 

Epoch : 2
Training loss: 1.1683, Training accuracy: 58.3217, Val loss: 1.1550, Val accuracy: 60.0870, 

Epoch : 3
Training loss: 1.0506, Training accuracy: 63.4684, Val loss: 1.0520, Val accuracy: 62.7868, 

Epoch : 4
Training loss: 0.9430, Training accuracy: 67.0253, Val loss: 0.9633, Val accuracy: 67.1282, 

Epoch : 5
Training loss: 0.8656, Training accuracy: 70.0904, Val loss: 0.8751, Val accuracy: 69.7093, 

Epoch : 6
Training loss: 0.8254, Training accuracy: 71.5730, Val loss: 0.8649, Val accuracy: 70.5202, 

Epoch : 7
Training loss: 0.7686, Training accuracy: 73.9567, Val loss: 0.8146, Val accuracy: 72.4090, 

Epoch : 8
Training loss: 0.7284, Training accuracy: 75.1498, Val loss: 0.7914, Val accuracy: 72.9826, 

Epoch : 9
Training loss: 0.7051, Training accuracy: 76.2305, Val loss: 0.7

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.7700, Training accuracy: 35.0115, Val loss: 1.5519, Val accuracy: 43.5918, 

Epoch : 1
Training loss: 1.4398, Training accuracy: 48.4225, Val loss: 1.3133, Val accuracy: 52.8382, 

Epoch : 2
Training loss: 1.2132, Training accuracy: 56.8291, Val loss: 1.1723, Val accuracy: 58.9300, 

Epoch : 3
Training loss: 1.0604, Training accuracy: 62.7771, Val loss: 1.0464, Val accuracy: 63.8548, 

Epoch : 4
Training loss: 0.9456, Training accuracy: 66.7232, Val loss: 0.9326, Val accuracy: 67.7512, 

Epoch : 5
Training loss: 0.8605, Training accuracy: 69.9905, Val loss: 0.8847, Val accuracy: 69.0961, 

Epoch : 6
Training loss: 0.8086, Training accuracy: 72.0572, Val loss: 0.8629, Val accuracy: 70.1048, 

Epoch : 7
Training loss: 0.7582, Training accuracy: 73.9791, Val loss: 0.8128, Val accuracy: 72.3002, 

Epoch : 8
Training loss: 0.7244, Training accuracy: 75.5491, Val loss: 0.8031, Val accuracy: 72.3794, 

Epoch : 9
Training loss: 0.6931, Training accuracy: 76.7871, Val loss: 0.7

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.7167, Training accuracy: 36.8910, Val loss: 1.5732, Val accuracy: 43.2160, 

Epoch : 1
Training loss: 1.3737, Training accuracy: 50.4518, Val loss: 1.3257, Val accuracy: 52.9272, 

Epoch : 2
Training loss: 1.1644, Training accuracy: 58.5363, Val loss: 1.1641, Val accuracy: 59.7903, 

Epoch : 3
Training loss: 1.0166, Training accuracy: 64.1623, Val loss: 1.0353, Val accuracy: 62.9549, 

Epoch : 4
Training loss: 0.9048, Training accuracy: 68.0137, Val loss: 0.9423, Val accuracy: 66.9601, 

Epoch : 5
Training loss: 0.8237, Training accuracy: 71.0613, Val loss: 0.8850, Val accuracy: 69.4818, 

Epoch : 6
Training loss: 0.7860, Training accuracy: 72.7910, Val loss: 0.8445, Val accuracy: 70.8960, 

Epoch : 7
Training loss: 0.7397, Training accuracy: 74.4833, Val loss: 0.7989, Val accuracy: 72.8639, 

Epoch : 8
Training loss: 0.7010, Training accuracy: 75.8761, Val loss: 0.7837, Val accuracy: 72.7156, 
Epoch : 9
Training loss: 0.6815, Training accuracy: 76.8096, Val loss: 0.75

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.13.1


Training loss: 1.7869, Training accuracy: 34.0880, Val loss: 1.6270, Val accuracy: 37.6780, 

1
Training loss: 1.4640, Training accuracy: 46.8151, Val loss: 1.3923, Val accuracy: 50.0297, 

2
Training loss: 1.2914, Training accuracy: 54.2507, Val loss: 1.2770, Val accuracy: 54.4007, 

3
Training loss: 1.1616, Training accuracy: 59.0031, Val loss: 1.1777, Val accuracy: 58.2971, 

4
Training loss: 1.0640, Training accuracy: 62.5824, Val loss: 1.0802, Val accuracy: 61.7583, 

5
Training loss: 1.0009, Training accuracy: 65.1782, Val loss: 1.0492, Val accuracy: 62.4506, 

6
Training loss: 0.9303, Training accuracy: 67.8065, Val loss: 0.9626, Val accuracy: 66.6535, 

7
Training loss: 0.8870, Training accuracy: 69.1194, Val loss: 0.9544, Val accuracy: 67.5040, 

8
Training loss: 0.8367, Training accuracy: 70.8342, Val loss: 0.8787, Val accuracy: 69.6598, 

9
Training loss: 0.7903, Training accuracy: 72.5614, Val loss: 0.8720, Val accuracy: 70.1839, 

10
Training loss: 0.7714, Training accurac

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0




alpha = 0.2



Epsilon = 1.0

student churn = 0.1741
baseline churn = 0.1735
student accuracy = 0.7761
baseline accuracy = 0.8032
teacher accuracy = 0.7789
student wlr = 0.9293078184127808
baseline wlr = 1.4410163164138794
churn ratio = 1.0034582132564842
student good_churn = 0.0631
student bad_churn = 0.0001
baseline good_churn = 0.0001
baseline bad_churn = 0.0002


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0




alpha = 0.4



Epsilon = 1.0

student churn = 0.1694
baseline churn = 0.1735
student accuracy = 0.7817
baseline accuracy = 0.8032
teacher accuracy = 0.7789
student wlr = 1.0455259084701538
baseline wlr = 1.4410163164138794
churn ratio = 0.9763688760806917
student good_churn = 0.0666
student bad_churn = 0.0
baseline good_churn = 0.0001
baseline bad_churn = 0.0002


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0




alpha = 0.6



Epsilon = 1.0

student churn = 0.1682
baseline churn = 0.1735
student accuracy = 0.7863
baseline accuracy = 0.8032
teacher accuracy = 0.7789
student wlr = 1.125619888305664
baseline wlr = 1.4410163164138794
churn ratio = 0.9694524495677234
student good_churn = 0.0681
student bad_churn = 0.0001
baseline good_churn = 0.0001
baseline bad_churn = 0.0002


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0




alpha = 0.8



Epsilon = 1.0

student churn = 0.1641
baseline churn = 0.1735
student accuracy = 0.7892
baseline accuracy = 0.8032
teacher accuracy = 0.7789
student wlr = 1.0823723077774048
baseline wlr = 1.4410163164138794
churn ratio = 0.945821325648415
student good_churn = 0.0657
student bad_churn = 0.0002
baseline good_churn = 0.0001
baseline bad_churn = 0.0002


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0




lamda = 0.2

student churn = 0.1601
baseline churn = 0.1716
student accuracy = 0.8126
baseline accuracy = 0.8
teacher accuracy = 0.7825
student wlr = 1.770975112915039
baseline wlr = 1.2996575832366943
churn ratio = 0.932983682983683
student good_churn = 0.0781
student bad_churn = 0.0002
baseline good_churn = 0.0002
baseline bad_churn = 0.0002


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0




lamda = 0.4

student churn = 0.1479
baseline churn = 0.1716
student accuracy = 0.8092
baseline accuracy = 0.8
teacher accuracy = 0.7825
student wlr = 1.735576868057251
baseline wlr = 1.2996575832366943
churn ratio = 0.8618881118881119
student good_churn = 0.0722
student bad_churn = 0.0002
baseline good_churn = 0.0002
baseline bad_churn = 0.0002


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0




lamda = 0.6

student churn = 0.1402
baseline churn = 0.1716
student accuracy = 0.804
baseline accuracy = 0.8
teacher accuracy = 0.7825
student wlr = 1.563380241394043
baseline wlr = 1.2996575832366943
churn ratio = 0.8170163170163169
student good_churn = 0.0666
student bad_churn = 0.0001
baseline good_churn = 0.0002
baseline bad_churn = 0.0002


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0




lamda = 0.8

student churn = 0.123
baseline churn = 0.1716
student accuracy = 0.7948
baseline accuracy = 0.8
teacher accuracy = 0.7825
student wlr = 1.3798449039459229
baseline wlr = 1.2996575832366943
churn ratio = 0.7167832167832168
student good_churn = 0.0534
student bad_churn = 0.0
baseline good_churn = 0.0002
baseline bad_churn = 0.0002


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Epoch : 0
Training loss: 1.8165, Training accuracy: 32.9854, Val loss: 1.6665, Val accuracy: 38.8944, 

Epoch : 1
Training loss: 1.5529, Training accuracy: 43.3599, Val loss: 1.5179, Val accuracy: 45.6092, 

Epoch : 2
Training loss: 1.3531, Training accuracy: 51.9116, Val loss: 1.3434, Val accuracy: 52.7987, 

Epoch : 3
Training loss: 1.2243, Training accuracy: 56.8418, Val loss: 1.2257, Val accuracy: 57.1005, 

Epoch : 4
Training loss: 1.1324, Training accuracy: 60.0776, Val loss: 1.1454, Val accuracy: 59.7706, 

Epoch : 5
Training loss: 1.0737, Training accuracy: 62.4523, Val loss: 1.0717, Val accuracy: 62.6187, 

Epoch : 6
Training loss: 0.9969, Training accuracy: 65.3125, Val loss: 1.0406, Val accuracy: 63.5186, 

Epoch : 7
Training loss: 0.9527, Training accuracy: 66.7852, Val loss: 1.0137, Val accuracy: 65.1009, 

Epoch : 8
Training loss: 0.8929, Training accuracy: 69.0780, Val loss: 0.9323, Val accuracy: 68.0973, 

Epoch : 9
Training loss: 0.8625, Training accuracy: 70.1086, Val

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.8340, Training accuracy: 32.3149, Val loss: 1.7060, Val accuracy: 37.0352, 

Epoch : 1
Training loss: 1.5535, Training accuracy: 43.7788, Val loss: 1.4946, Val accuracy: 45.3323, 

Epoch : 2
Training loss: 1.3944, Training accuracy: 50.2881, Val loss: 1.3019, Val accuracy: 51.9086, 

Epoch : 3
Training loss: 1.2450, Training accuracy: 55.6549, Val loss: 1.2637, Val accuracy: 55.3896, 

Epoch : 4
Training loss: 1.1433, Training accuracy: 60.2150, Val loss: 1.1640, Val accuracy: 58.6926, 

Epoch : 5
Training loss: 1.0764, Training accuracy: 62.6474, Val loss: 1.1115, Val accuracy: 61.6396, 

Epoch : 6
Training loss: 1.0015, Training accuracy: 65.1008, Val loss: 1.0207, Val accuracy: 64.1911, 

Epoch : 7
Training loss: 0.9452, Training accuracy: 67.5155, Val loss: 0.9842, Val accuracy: 66.0700, 

Epoch : 8
Training loss: 0.8997, Training accuracy: 69.1567, Val loss: 0.9471, Val accuracy: 67.3161, 

Epoch : 9
Training loss: 0.8646, Training accuracy: 70.4865, Val loss: 0.9

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.8370, Training accuracy: 32.2407, Val loss: 1.6434, Val accuracy: 39.1119, 

Epoch : 1
Training loss: 1.5617, Training accuracy: 43.6425, Val loss: 1.4768, Val accuracy: 46.1036, 

Epoch : 2
Training loss: 1.4014, Training accuracy: 49.9911, Val loss: 1.3727, Val accuracy: 50.9691, 

Epoch : 3
Training loss: 1.2716, Training accuracy: 55.0964, Val loss: 1.2170, Val accuracy: 55.5973, 

Epoch : 4
Training loss: 1.1817, Training accuracy: 58.8819, Val loss: 1.1706, Val accuracy: 58.3366, 

Epoch : 5
Training loss: 1.1416, Training accuracy: 60.2770, Val loss: 1.1223, Val accuracy: 60.8386, 

Epoch : 6
Training loss: 1.0629, Training accuracy: 63.2868, Val loss: 1.0418, Val accuracy: 64.0131, 

Epoch : 7
Training loss: 0.9890, Training accuracy: 66.2367, Val loss: 1.0195, Val accuracy: 64.5075, 

Epoch : 8
Training loss: 0.9594, Training accuracy: 67.8158, Val loss: 0.9597, Val accuracy: 66.8216, 

Epoch : 9
Training loss: 0.9230, Training accuracy: 68.7201, Val loss: 0.9

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.8310, Training accuracy: 32.7316, Val loss: 1.7879, Val accuracy: 35.9968, 

Epoch : 1
Training loss: 1.5553, Training accuracy: 44.0392, Val loss: 1.4889, Val accuracy: 45.8564, 

Epoch : 2
Training loss: 1.3817, Training accuracy: 51.7099, Val loss: 1.3172, Val accuracy: 53.5799, 

Epoch : 3
Training loss: 1.2478, Training accuracy: 56.7775, Val loss: 1.1884, Val accuracy: 57.9510, 

Epoch : 4
Training loss: 1.1559, Training accuracy: 60.2305, Val loss: 1.1148, Val accuracy: 60.4430, 

Epoch : 5
Training loss: 1.0675, Training accuracy: 63.8675, Val loss: 1.0559, Val accuracy: 62.9252, 

Epoch : 6
Training loss: 1.0192, Training accuracy: 65.7513, Val loss: 0.9647, Val accuracy: 66.2184, 

Epoch : 7
Training loss: 0.9526, Training accuracy: 68.3743, Val loss: 0.9663, Val accuracy: 66.1689, 
Epoch : 8
Training loss: 0.9229, Training accuracy: 69.5301, Val loss: 0.9320, Val accuracy: 67.3161, 

Epoch : 9
Training loss: 0.8863, Training accuracy: 71.0594, Val loss: 0.89

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.8452, Training accuracy: 32.3061, Val loss: 1.7057, Val accuracy: 38.3208, 

Epoch : 1
Training loss: 1.5759, Training accuracy: 43.4065, Val loss: 1.5071, Val accuracy: 43.8786, 

Epoch : 2
Training loss: 1.4213, Training accuracy: 49.8593, Val loss: 1.3612, Val accuracy: 51.6911, 

Epoch : 3
Training loss: 1.2927, Training accuracy: 55.6627, Val loss: 1.2117, Val accuracy: 56.5566, 

Epoch : 4
Training loss: 1.1886, Training accuracy: 59.5423, Val loss: 1.1976, Val accuracy: 57.8323, 

Epoch : 5
Training loss: 1.1157, Training accuracy: 62.2606, Val loss: 1.0435, Val accuracy: 63.4098, 

Epoch : 6
Training loss: 1.0415, Training accuracy: 65.1729, Val loss: 0.9840, Val accuracy: 65.5953, 

Epoch : 7
Training loss: 0.9890, Training accuracy: 67.3570, Val loss: 0.9705, Val accuracy: 66.2480, 

Epoch : 8
Training loss: 0.9424, Training accuracy: 68.8985, Val loss: 0.9093, Val accuracy: 68.6610, 

Epoch : 9
Training loss: 0.8994, Training accuracy: 70.8588, Val loss: 0.8

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.7714, Training accuracy: 34.9241, Val loss: 1.5934, Val accuracy: 42.4248, 

Epoch : 1
Training loss: 1.4676, Training accuracy: 47.0497, Val loss: 1.4376, Val accuracy: 48.4474, 

Epoch : 2
Training loss: 1.2649, Training accuracy: 55.0569, Val loss: 1.2487, Val accuracy: 54.9446, 

Epoch : 3
Training loss: 1.1340, Training accuracy: 60.2935, Val loss: 1.1294, Val accuracy: 60.2749, 

Epoch : 4
Training loss: 1.0359, Training accuracy: 63.8254, Val loss: 0.9959, Val accuracy: 65.4865, 

Epoch : 5
Training loss: 0.9570, Training accuracy: 66.7482, Val loss: 0.9662, Val accuracy: 66.9996, 

Epoch : 6
Training loss: 0.8871, Training accuracy: 69.2168, Val loss: 0.9041, Val accuracy: 68.4632, 

Epoch : 7
Training loss: 0.8343, Training accuracy: 71.1736, Val loss: 0.8852, Val accuracy: 69.8279, 

Epoch : 8
Training loss: 0.8126, Training accuracy: 71.8001, Val loss: 0.8925, Val accuracy: 69.4917, 
Epoch : 9
Training loss: 0.7535, Training accuracy: 74.2387, Val loss: 0.77

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.7617, Training accuracy: 35.3459, Val loss: 1.5641, Val accuracy: 41.6634, 

Epoch : 1
Training loss: 1.4247, Training accuracy: 48.6047, Val loss: 1.3514, Val accuracy: 51.2263, 

Epoch : 2
Training loss: 1.2161, Training accuracy: 56.8765, Val loss: 1.1725, Val accuracy: 58.0597, 

Epoch : 3
Training loss: 1.0587, Training accuracy: 62.7671, Val loss: 1.0747, Val accuracy: 62.2033, 

Epoch : 4
Training loss: 0.9663, Training accuracy: 66.4412, Val loss: 0.9735, Val accuracy: 66.3172, 

Epoch : 5
Training loss: 0.8793, Training accuracy: 69.2242, Val loss: 0.9275, Val accuracy: 67.8501, 

Epoch : 6
Training loss: 0.8155, Training accuracy: 71.8800, Val loss: 0.8690, Val accuracy: 69.9466, 

Epoch : 7
Training loss: 0.7698, Training accuracy: 73.5623, Val loss: 0.8571, Val accuracy: 70.0257, 

Epoch : 8
Training loss: 0.7391, Training accuracy: 74.6630, Val loss: 0.7734, Val accuracy: 73.3683, 

Epoch : 9
Training loss: 0.7016, Training accuracy: 76.3428, Val loss: 0.7

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.7367, Training accuracy: 36.1896, Val loss: 1.5940, Val accuracy: 40.8722, 

Epoch : 1
Training loss: 1.3837, Training accuracy: 49.9651, Val loss: 1.3712, Val accuracy: 51.0483, 

Epoch : 2
Training loss: 1.1862, Training accuracy: 57.6003, Val loss: 1.1493, Val accuracy: 59.5036, 

Epoch : 3
Training loss: 1.0410, Training accuracy: 63.5708, Val loss: 1.0586, Val accuracy: 62.8362, 

Epoch : 4
Training loss: 0.9452, Training accuracy: 66.7182, Val loss: 1.0080, Val accuracy: 64.7646, 

Epoch : 5
Training loss: 0.8915, Training accuracy: 68.8049, Val loss: 0.9688, Val accuracy: 66.8908, 

Epoch : 6
Training loss: 0.8161, Training accuracy: 71.9524, Val loss: 0.8832, Val accuracy: 69.9367, 

Epoch : 7
Training loss: 0.7835, Training accuracy: 72.9608, Val loss: 0.8278, Val accuracy: 71.5091, 

Epoch : 8
Training loss: 0.7389, Training accuracy: 74.5058, Val loss: 0.8182, Val accuracy: 71.7662, 

Epoch : 9
Training loss: 0.6973, Training accuracy: 76.2555, Val loss: 0.7

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.7132, Training accuracy: 37.5474, Val loss: 1.5563, Val accuracy: 43.5324, 

Epoch : 1
Training loss: 1.3690, Training accuracy: 50.5891, Val loss: 1.2884, Val accuracy: 53.0953, 

Epoch : 2
Training loss: 1.1423, Training accuracy: 59.4199, Val loss: 1.1357, Val accuracy: 60.6804, 

Epoch : 3
Training loss: 0.9933, Training accuracy: 64.8537, Val loss: 0.9736, Val accuracy: 66.3172, 

Epoch : 4
Training loss: 0.8878, Training accuracy: 68.7924, Val loss: 0.9177, Val accuracy: 68.6313, 

Epoch : 5
Training loss: 0.8159, Training accuracy: 71.3958, Val loss: 0.8431, Val accuracy: 70.9157, 

Epoch : 6
Training loss: 0.7664, Training accuracy: 73.1579, Val loss: 0.8202, Val accuracy: 71.9146, 

Epoch : 7
Training loss: 0.7263, Training accuracy: 75.0948, Val loss: 0.7881, Val accuracy: 73.3683, 

Epoch : 8
Training loss: 0.6984, Training accuracy: 75.9160, Val loss: 0.7596, Val accuracy: 73.8133, 

Epoch : 9
Training loss: 0.6750, Training accuracy: 77.0792, Val loss: 0.7

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.13.1


Training loss: 1.7872, Training accuracy: 34.2402, Val loss: 1.6479, Val accuracy: 40.5657, 

1
Training loss: 1.4687, Training accuracy: 46.8226, Val loss: 1.4142, Val accuracy: 48.6650, 

2
Training loss: 1.2821, Training accuracy: 54.2657, Val loss: 1.2143, Val accuracy: 56.5269, 

3
Training loss: 1.1484, Training accuracy: 59.4624, Val loss: 1.1751, Val accuracy: 59.1871, 

4
Training loss: 1.0891, Training accuracy: 61.8236, Val loss: 1.0403, Val accuracy: 63.1032, 

5
Training loss: 0.9904, Training accuracy: 65.7398, Val loss: 1.0250, Val accuracy: 64.5372, 

6
Training loss: 0.9168, Training accuracy: 68.0886, Val loss: 0.9333, Val accuracy: 67.3457, 

7
Training loss: 0.8630, Training accuracy: 69.9656, Val loss: 0.8885, Val accuracy: 68.6511, 

8
Training loss: 0.8170, Training accuracy: 71.5081, Val loss: 0.8596, Val accuracy: 70.7773, 

9
Training loss: 0.7925, Training accuracy: 72.7386, Val loss: 0.8811, Val accuracy: 69.5016, 
10
Training loss: 0.7632, Training accuracy

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0




alpha = 0.2



Epsilon = 1.0

student churn = 0.1777
baseline churn = 0.1709
student accuracy = 0.7724
baseline accuracy = 0.8125
teacher accuracy = 0.7905
student wlr = 0.8114973306655884
baseline wlr = 1.3992739915847778
churn ratio = 1.039789350497367
student good_churn = 0.0607
student bad_churn = 0.0001
baseline good_churn = 0.0001
baseline bad_churn = 0.0002


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0




alpha = 0.4



Epsilon = 1.0

student churn = 0.1657
baseline churn = 0.1709
student accuracy = 0.7841
baseline accuracy = 0.8125
teacher accuracy = 0.7905
student wlr = 0.9282442927360535
baseline wlr = 1.3992739915847778
churn ratio = 0.9695728496196606
student good_churn = 0.0608
student bad_churn = 0.0001
baseline good_churn = 0.0001
baseline bad_churn = 0.0002


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0




alpha = 0.6



Epsilon = 1.0

student churn = 0.1662
baseline churn = 0.1709
student accuracy = 0.7884
baseline accuracy = 0.8125
teacher accuracy = 0.7905
student wlr = 0.9815950989723206
baseline wlr = 1.3992739915847778
churn ratio = 0.9724985371562317
student good_churn = 0.064
student bad_churn = 0.0002
baseline good_churn = 0.0001
baseline bad_churn = 0.0002


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0




alpha = 0.8



Epsilon = 1.0

student churn = 0.1595
baseline churn = 0.1709
student accuracy = 0.7885
baseline accuracy = 0.8125
teacher accuracy = 0.7905
student wlr = 0.9566613435745239
baseline wlr = 1.3992739915847778
churn ratio = 0.9332943241661791
student good_churn = 0.0596
student bad_churn = 0.0001
baseline good_churn = 0.0001
baseline bad_churn = 0.0002


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0




lamda = 0.2

student churn = 0.1597
baseline churn = 0.1699
student accuracy = 0.81
baseline accuracy = 0.8066
teacher accuracy = 0.7882
student wlr = 1.4038095474243164
baseline wlr = 1.3291592597961426
churn ratio = 0.9399646851088876
student good_churn = 0.0737
student bad_churn = 0.0003
baseline good_churn = 0.0
baseline bad_churn = 0.0001


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0




lamda = 0.4

student churn = 0.1485
baseline churn = 0.1699
student accuracy = 0.8104
baseline accuracy = 0.8066
teacher accuracy = 0.7882
student wlr = 1.4453781843185425
baseline wlr = 1.3291592597961426
churn ratio = 0.874043555032372
student good_churn = 0.0688
student bad_churn = 0.0001
baseline good_churn = 0.0
baseline bad_churn = 0.0001


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0




lamda = 0.6

student churn = 0.1338
baseline churn = 0.1699
student accuracy = 0.8037
baseline accuracy = 0.8066
teacher accuracy = 0.7882
student wlr = 1.3075170516967773
baseline wlr = 1.3291592597961426
churn ratio = 0.7875220718069453
student good_churn = 0.0574
student bad_churn = 0.0001
baseline good_churn = 0.0
baseline bad_churn = 0.0001


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0




lamda = 0.8

student churn = 0.1292
baseline churn = 0.1699
student accuracy = 0.7982
baseline accuracy = 0.8066
teacher accuracy = 0.7882
student wlr = 1.2305986881256104
baseline wlr = 1.3291592597961426
churn ratio = 0.7604473219540907
student good_churn = 0.0555
student bad_churn = 0.0001
baseline good_churn = 0.0
baseline bad_churn = 0.0001


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Epoch : 0
Training loss: 1.8177, Training accuracy: 33.3921, Val loss: 1.6588, Val accuracy: 39.1021, 

Epoch : 1
Training loss: 1.5514, Training accuracy: 44.0481, Val loss: 1.4820, Val accuracy: 46.8750, 

Epoch : 2
Training loss: 1.3856, Training accuracy: 50.2748, Val loss: 1.3414, Val accuracy: 52.5415, 

Epoch : 3
Training loss: 1.2530, Training accuracy: 55.5685, Val loss: 1.2140, Val accuracy: 56.9917, 

Epoch : 4
Training loss: 1.1586, Training accuracy: 59.1933, Val loss: 1.1236, Val accuracy: 60.3343, 

Epoch : 5
Training loss: 1.0907, Training accuracy: 61.8828, Val loss: 1.0968, Val accuracy: 61.6594, 

Epoch : 6
Training loss: 1.0119, Training accuracy: 64.3218, Val loss: 1.0661, Val accuracy: 63.7955, 

Epoch : 7
Training loss: 0.9689, Training accuracy: 65.9530, Val loss: 1.0175, Val accuracy: 64.8042, 

Epoch : 8
Training loss: 0.9126, Training accuracy: 68.1715, Val loss: 0.9915, Val accuracy: 65.7338, 

Epoch : 9
Training loss: 0.8741, Training accuracy: 69.1279, Val

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.8581, Training accuracy: 31.1026, Val loss: 1.6675, Val accuracy: 37.9648, 

Epoch : 1
Training loss: 1.5689, Training accuracy: 43.1250, Val loss: 1.4779, Val accuracy: 46.4597, 

Epoch : 2
Training loss: 1.3968, Training accuracy: 50.2261, Val loss: 1.3399, Val accuracy: 51.7801, 

Epoch : 3
Training loss: 1.2489, Training accuracy: 55.5042, Val loss: 1.2220, Val accuracy: 56.3885, 

Epoch : 4
Training loss: 1.1479, Training accuracy: 59.7363, Val loss: 1.1577, Val accuracy: 58.8608, 

Epoch : 5
Training loss: 1.0784, Training accuracy: 62.6341, Val loss: 1.0901, Val accuracy: 61.5012, 

Epoch : 6
Training loss: 1.0137, Training accuracy: 64.8016, Val loss: 1.0286, Val accuracy: 63.8054, 

Epoch : 7
Training loss: 0.9513, Training accuracy: 67.2717, Val loss: 0.9867, Val accuracy: 65.8030, 

Epoch : 8
Training loss: 0.9244, Training accuracy: 68.3090, Val loss: 0.9415, Val accuracy: 67.1578, 

Epoch : 9
Training loss: 0.8728, Training accuracy: 70.0676, Val loss: 0.9

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.8279, Training accuracy: 32.6152, Val loss: 1.7038, Val accuracy: 37.1242, 

Epoch : 1
Training loss: 1.5711, Training accuracy: 43.2513, Val loss: 1.4967, Val accuracy: 45.0455, 

Epoch : 2
Training loss: 1.4011, Training accuracy: 50.6316, Val loss: 1.3926, Val accuracy: 49.9209, 

Epoch : 3
Training loss: 1.2662, Training accuracy: 55.4532, Val loss: 1.2623, Val accuracy: 55.0534, 

Epoch : 4
Training loss: 1.1737, Training accuracy: 59.1090, Val loss: 1.1348, Val accuracy: 60.3639, 

Epoch : 5
Training loss: 1.1047, Training accuracy: 61.6212, Val loss: 1.1155, Val accuracy: 59.9585, 
Epoch : 6
Training loss: 1.0339, Training accuracy: 64.4249, Val loss: 1.0246, Val accuracy: 63.7164, 

Epoch : 7
Training loss: 0.9851, Training accuracy: 66.2378, Val loss: 1.0241, Val accuracy: 64.2010, 

Epoch : 8
Training loss: 0.9472, Training accuracy: 68.0829, Val loss: 0.9785, Val accuracy: 65.7832, 

Epoch : 9
Training loss: 0.8890, Training accuracy: 70.0632, Val loss: 0.92

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.8146, Training accuracy: 33.6458, Val loss: 1.6390, Val accuracy: 40.7239, 

Epoch : 1
Training loss: 1.5400, Training accuracy: 44.5844, Val loss: 1.4088, Val accuracy: 48.3089, 

Epoch : 2
Training loss: 1.3611, Training accuracy: 51.9215, Val loss: 1.3590, Val accuracy: 50.7219, 

Epoch : 3
Training loss: 1.2469, Training accuracy: 56.6633, Val loss: 1.2014, Val accuracy: 57.4070, 

Epoch : 4
Training loss: 1.1439, Training accuracy: 60.5020, Val loss: 1.1421, Val accuracy: 59.5629, 

Epoch : 5
Training loss: 1.0673, Training accuracy: 63.6536, Val loss: 1.0490, Val accuracy: 63.1329, 

Epoch : 6
Training loss: 1.0153, Training accuracy: 65.7037, Val loss: 0.9914, Val accuracy: 65.0514, 

Epoch : 7
Training loss: 0.9757, Training accuracy: 66.9481, Val loss: 0.9810, Val accuracy: 65.0811, 

Epoch : 8
Training loss: 0.9200, Training accuracy: 69.3362, Val loss: 0.9147, Val accuracy: 68.2061, 

Epoch : 9
Training loss: 0.8813, Training accuracy: 70.5918, Val loss: 0.9

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.8344, Training accuracy: 32.3870, Val loss: 1.6748, Val accuracy: 37.9747, 

Epoch : 1
Training loss: 1.5800, Training accuracy: 43.0086, Val loss: 1.4380, Val accuracy: 47.2607, 

Epoch : 2
Training loss: 1.3978, Training accuracy: 50.3424, Val loss: 1.3568, Val accuracy: 50.1187, 

Epoch : 3
Training loss: 1.2851, Training accuracy: 54.8936, Val loss: 1.2731, Val accuracy: 54.7666, 

Epoch : 4
Training loss: 1.1975, Training accuracy: 58.4497, Val loss: 1.1587, Val accuracy: 59.2860, 

Epoch : 5
Training loss: 1.1168, Training accuracy: 62.1698, Val loss: 1.1168, Val accuracy: 60.7595, 

Epoch : 6
Training loss: 1.0539, Training accuracy: 64.5800, Val loss: 1.0576, Val accuracy: 62.7472, 

Epoch : 7
Training loss: 1.0099, Training accuracy: 66.3774, Val loss: 1.0277, Val accuracy: 64.2405, 

Epoch : 8
Training loss: 0.9518, Training accuracy: 68.5838, Val loss: 0.9448, Val accuracy: 67.1282, 

Epoch : 9
Training loss: 0.9095, Training accuracy: 70.2460, Val loss: 0.9

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.7418, Training accuracy: 36.3968, Val loss: 1.5534, Val accuracy: 44.6104, 

Epoch : 1
Training loss: 1.4205, Training accuracy: 48.9392, Val loss: 1.3271, Val accuracy: 51.9086, 

Epoch : 2
Training loss: 1.2306, Training accuracy: 56.2700, Val loss: 1.1462, Val accuracy: 60.5320, 

Epoch : 3
Training loss: 1.1044, Training accuracy: 60.9200, Val loss: 1.0714, Val accuracy: 63.0340, 

Epoch : 4
Training loss: 1.0029, Training accuracy: 64.7314, Val loss: 1.0520, Val accuracy: 62.7769, 
Epoch : 5
Training loss: 0.9272, Training accuracy: 67.4920, Val loss: 0.9889, Val accuracy: 66.1689, 

Epoch : 6
Training loss: 0.8707, Training accuracy: 69.4688, Val loss: 0.8904, Val accuracy: 69.5807, 

Epoch : 7
Training loss: 0.8187, Training accuracy: 71.6479, Val loss: 0.8520, Val accuracy: 70.8465, 

Epoch : 8
Training loss: 0.7887, Training accuracy: 72.8035, Val loss: 0.8262, Val accuracy: 71.6772, 

Epoch : 9
Training loss: 0.7583, Training accuracy: 73.8144, Val loss: 0.82

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.7276, Training accuracy: 36.9434, Val loss: 1.5441, Val accuracy: 42.6028, 

Epoch : 1
Training loss: 1.3872, Training accuracy: 49.8827, Val loss: 1.3221, Val accuracy: 52.7591, 

Epoch : 2
Training loss: 1.1733, Training accuracy: 58.2892, Val loss: 1.1510, Val accuracy: 59.3157, 

Epoch : 3
Training loss: 1.0389, Training accuracy: 63.6631, Val loss: 1.0945, Val accuracy: 61.6990, 

Epoch : 4
Training loss: 0.9508, Training accuracy: 66.9554, Val loss: 0.9613, Val accuracy: 66.1294, 

Epoch : 5
Training loss: 0.8883, Training accuracy: 69.0246, Val loss: 0.9460, Val accuracy: 67.4644, 

Epoch : 6
Training loss: 0.8195, Training accuracy: 71.6504, Val loss: 0.8748, Val accuracy: 69.4324, 

Epoch : 7
Training loss: 0.7726, Training accuracy: 73.1305, Val loss: 0.8501, Val accuracy: 71.0542, 

Epoch : 8
Training loss: 0.7338, Training accuracy: 75.0549, Val loss: 0.8000, Val accuracy: 72.4782, 

Epoch : 9
Training loss: 0.7047, Training accuracy: 76.2705, Val loss: 0.7

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.7212, Training accuracy: 36.6489, Val loss: 1.5038, Val accuracy: 46.1036, 

Epoch : 1
Training loss: 1.3696, Training accuracy: 50.5791, Val loss: 1.2841, Val accuracy: 53.7975, 

Epoch : 2
Training loss: 1.1517, Training accuracy: 59.0905, Val loss: 1.1078, Val accuracy: 61.2737, 

Epoch : 3
Training loss: 1.0223, Training accuracy: 63.4635, Val loss: 1.0094, Val accuracy: 64.7646, 

Epoch : 4
Training loss: 0.9238, Training accuracy: 67.4196, Val loss: 0.9745, Val accuracy: 65.9711, 

Epoch : 5
Training loss: 0.8809, Training accuracy: 69.3141, Val loss: 0.9274, Val accuracy: 68.1764, 

Epoch : 6
Training loss: 0.8147, Training accuracy: 71.8326, Val loss: 0.8478, Val accuracy: 71.1828, 

Epoch : 7
Training loss: 0.7676, Training accuracy: 73.7170, Val loss: 0.8357, Val accuracy: 71.6673, 

Epoch : 8
Training loss: 0.7402, Training accuracy: 74.5956, Val loss: 0.7960, Val accuracy: 72.9134, 

Epoch : 9
Training loss: 0.7096, Training accuracy: 75.5841, Val loss: 0.7

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.7101, Training accuracy: 36.8535, Val loss: 1.5117, Val accuracy: 45.9652, 

Epoch : 1
Training loss: 1.3690, Training accuracy: 50.4917, Val loss: 1.3302, Val accuracy: 53.4217, 

Epoch : 2
Training loss: 1.1380, Training accuracy: 59.2053, Val loss: 1.0972, Val accuracy: 61.7583, 

Epoch : 3
Training loss: 0.9940, Training accuracy: 64.7439, Val loss: 1.0263, Val accuracy: 64.7844, 

Epoch : 4
Training loss: 0.8941, Training accuracy: 68.0711, Val loss: 0.9041, Val accuracy: 69.3335, 

Epoch : 5
Training loss: 0.8138, Training accuracy: 71.2310, Val loss: 0.8757, Val accuracy: 69.7389, 

Epoch : 6
Training loss: 0.7653, Training accuracy: 73.1929, Val loss: 0.8273, Val accuracy: 71.1531, 

Epoch : 7
Training loss: 0.7272, Training accuracy: 74.5482, Val loss: 0.8042, Val accuracy: 72.5969, 

Epoch : 8
Training loss: 0.6894, Training accuracy: 75.9435, Val loss: 0.7762, Val accuracy: 73.6650, 

Epoch : 9
Training loss: 0.6682, Training accuracy: 77.1291, Val loss: 0.7

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.13.1


Training loss: 1.7740, Training accuracy: 35.1113, Val loss: 1.6028, Val accuracy: 40.1701, 

1
Training loss: 1.4594, Training accuracy: 47.2918, Val loss: 1.3890, Val accuracy: 49.1001, 

2
Training loss: 1.2607, Training accuracy: 54.8997, Val loss: 1.2454, Val accuracy: 55.2116, 

3
Training loss: 1.1491, Training accuracy: 59.5822, Val loss: 1.1135, Val accuracy: 61.0166, 

4
Training loss: 1.0405, Training accuracy: 63.5358, Val loss: 1.0873, Val accuracy: 62.8857, 

5
Training loss: 0.9663, Training accuracy: 66.5410, Val loss: 0.9643, Val accuracy: 66.3667, 

6
Training loss: 0.8867, Training accuracy: 69.4564, Val loss: 0.9582, Val accuracy: 66.7227, 

7
Training loss: 0.8459, Training accuracy: 70.7768, Val loss: 0.9021, Val accuracy: 68.8093, 

8
Training loss: 0.8079, Training accuracy: 72.1695, Val loss: 0.8762, Val accuracy: 69.8873, 

9
Training loss: 0.7779, Training accuracy: 73.1505, Val loss: 0.8670, Val accuracy: 70.9454, 

10
Training loss: 0.7308, Training accurac

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0




alpha = 0.2



Epsilon = 1.0

student churn = 0.1731
baseline churn = 0.1713
student accuracy = 0.7805
baseline accuracy = 0.8091
teacher accuracy = 0.786
student wlr = 0.9209370613098145
baseline wlr = 1.4342105388641357
churn ratio = 1.0105078809106829
student good_churn = 0.0629
student bad_churn = 0.0001
baseline good_churn = 0.0002
baseline bad_churn = 0.0004


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0




alpha = 0.4



Epsilon = 1.0

student churn = 0.1707
baseline churn = 0.1713
student accuracy = 0.7846
baseline accuracy = 0.8091
teacher accuracy = 0.786
student wlr = 1.029687523841858
baseline wlr = 1.4342105388641357
churn ratio = 0.9964973730297723
student good_churn = 0.0659
student bad_churn = 0.0
baseline good_churn = 0.0002
baseline bad_churn = 0.0004


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0




alpha = 0.6



Epsilon = 1.0

student churn = 0.1664
baseline churn = 0.1713
student accuracy = 0.7817
baseline accuracy = 0.8091
teacher accuracy = 0.786
student wlr = 0.9263157844543457
baseline wlr = 1.4342105388641357
churn ratio = 0.9713952130764739
student good_churn = 0.0616
student bad_churn = 0.0002
baseline good_churn = 0.0002
baseline bad_churn = 0.0004


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0




alpha = 0.8



Epsilon = 1.0

student churn = 0.1748
baseline churn = 0.1713
student accuracy = 0.7772
baseline accuracy = 0.8091
teacher accuracy = 0.786
student wlr = 0.9553313851356506
baseline wlr = 1.4342105388641357
churn ratio = 1.0204319906596615
student good_churn = 0.0663
student bad_churn = 0.0002
baseline good_churn = 0.0002
baseline bad_churn = 0.0004


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0




lamda = 0.2

student churn = 0.1607
baseline churn = 0.17
student accuracy = 0.8052
baseline accuracy = 0.8096
teacher accuracy = 0.7838
student wlr = 1.3802281618118286
baseline wlr = 1.4886363744735718
churn ratio = 0.9452941176470588
student good_churn = 0.0726
student bad_churn = 0.0001
baseline good_churn = 0.0
baseline bad_churn = 0.0001


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0




lamda = 0.4

student churn = 0.149
baseline churn = 0.17
student accuracy = 0.809
baseline accuracy = 0.8096
teacher accuracy = 0.7838
student wlr = 1.617511510848999
baseline wlr = 1.4886363744735718
churn ratio = 0.876470588235294
student good_churn = 0.0702
student bad_churn = 0.0002
baseline good_churn = 0.0
baseline bad_churn = 0.0001


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0




lamda = 0.6

student churn = 0.1406
baseline churn = 0.17
student accuracy = 0.8013
baseline accuracy = 0.8096
teacher accuracy = 0.7838
student wlr = 1.4105960130691528
baseline wlr = 1.4886363744735718
churn ratio = 0.8270588235294117
student good_churn = 0.0639
student bad_churn = 0.0
baseline good_churn = 0.0
baseline bad_churn = 0.0001


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0




lamda = 0.8

student churn = 0.1278
baseline churn = 0.17
student accuracy = 0.7995
baseline accuracy = 0.8096
teacher accuracy = 0.7838
student wlr = 1.3734643459320068
baseline wlr = 1.4886363744735718
churn ratio = 0.7517647058823529
student good_churn = 0.0559
student bad_churn = 0.0001
baseline good_churn = 0.0
baseline bad_churn = 0.0001


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Epoch : 0
Training loss: 1.8314, Training accuracy: 32.3404, Val loss: 1.6966, Val accuracy: 36.5605, 

Epoch : 1
Training loss: 1.5472, Training accuracy: 43.7245, Val loss: 1.4759, Val accuracy: 46.3014, 

Epoch : 2
Training loss: 1.3676, Training accuracy: 50.8754, Val loss: 1.3462, Val accuracy: 51.9482, 

Epoch : 3
Training loss: 1.2426, Training accuracy: 55.9286, Val loss: 1.2341, Val accuracy: 56.2401, 

Epoch : 4
Training loss: 1.1425, Training accuracy: 59.9080, Val loss: 1.1136, Val accuracy: 61.3034, 

Epoch : 5
Training loss: 1.0646, Training accuracy: 63.0452, Val loss: 1.0783, Val accuracy: 62.2330, 

Epoch : 6
Training loss: 1.0252, Training accuracy: 64.2797, Val loss: 1.1144, Val accuracy: 61.2836, 
Epoch : 7
Training loss: 0.9534, Training accuracy: 66.7952, Val loss: 1.0006, Val accuracy: 65.3877, 

Epoch : 8
Training loss: 0.9027, Training accuracy: 68.5173, Val loss: 0.9237, Val accuracy: 67.4842, 

Epoch : 9
Training loss: 0.8714, Training accuracy: 69.7584, Val 

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.8465, Training accuracy: 31.5869, Val loss: 1.6250, Val accuracy: 39.0922, 

Epoch : 1
Training loss: 1.5772, Training accuracy: 42.7460, Val loss: 1.5267, Val accuracy: 44.0368, 

Epoch : 2
Training loss: 1.4252, Training accuracy: 49.0193, Val loss: 1.3409, Val accuracy: 51.6218, 

Epoch : 3
Training loss: 1.2850, Training accuracy: 54.6554, Val loss: 1.2118, Val accuracy: 57.8224, 

Epoch : 4
Training loss: 1.1705, Training accuracy: 59.1090, Val loss: 1.1739, Val accuracy: 58.9893, 

Epoch : 5
Training loss: 1.0807, Training accuracy: 62.6939, Val loss: 1.1224, Val accuracy: 60.8584, 

Epoch : 6
Training loss: 1.0152, Training accuracy: 65.4311, Val loss: 0.9963, Val accuracy: 65.0119, 

Epoch : 7
Training loss: 0.9592, Training accuracy: 67.1664, Val loss: 0.9850, Val accuracy: 65.3184, 

Epoch : 8
Training loss: 0.9141, Training accuracy: 68.8198, Val loss: 0.9454, Val accuracy: 67.6523, 

Epoch : 9
Training loss: 0.8664, Training accuracy: 70.4610, Val loss: 0.9

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.8247, Training accuracy: 32.8480, Val loss: 1.6441, Val accuracy: 37.8659, 

Epoch : 1
Training loss: 1.5329, Training accuracy: 44.3163, Val loss: 1.4330, Val accuracy: 48.5067, 

Epoch : 2
Training loss: 1.3530, Training accuracy: 51.6611, Val loss: 1.3213, Val accuracy: 52.3339, 

Epoch : 3
Training loss: 1.2226, Training accuracy: 57.6164, Val loss: 1.2473, Val accuracy: 56.5763, 

Epoch : 4
Training loss: 1.1340, Training accuracy: 61.1370, Val loss: 1.1322, Val accuracy: 60.7199, 

Epoch : 5
Training loss: 1.0448, Training accuracy: 64.4459, Val loss: 1.0518, Val accuracy: 62.7373, 

Epoch : 6
Training loss: 1.0075, Training accuracy: 65.8688, Val loss: 0.9995, Val accuracy: 65.2393, 

Epoch : 7
Training loss: 0.9447, Training accuracy: 68.2535, Val loss: 0.9638, Val accuracy: 66.1986, 

Epoch : 8
Training loss: 0.9430, Training accuracy: 68.3500, Val loss: 0.9444, Val accuracy: 67.0392, 

Epoch : 9
Training loss: 0.9062, Training accuracy: 69.8227, Val loss: 0.9

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Training loss: 1.8303, Training accuracy: 33.2702, Val loss: 1.6606, Val accuracy: 39.3097, 

Epoch : 1
Training loss: 1.5680, Training accuracy: 43.7677, Val loss: 1.5388, Val accuracy: 42.5534, 

Epoch : 2
Training loss: 1.4021, Training accuracy: 50.1496, Val loss: 1.3237, Val accuracy: 51.5131, 

Epoch : 3
Training loss: 1.2622, Training accuracy: 56.2555, Val loss: 1.2771, Val accuracy: 54.3414, 

Epoch : 4
Training loss: 1.1710, Training accuracy: 59.8969, Val loss: 1.1399, Val accuracy: 59.4937, 

Epoch : 5
Training loss: 1.0915, Training accuracy: 62.7914, Val loss: 1.0849, Val accuracy: 62.2231, 

Epoch : 6
Training loss: 1.0342, Training accuracy: 65.2083, Val loss: 1.0394, Val accuracy: 64.2306, 

Epoch : 7
Training loss: 0.9799, Training accuracy: 67.0512, Val loss: 0.9740, Val accuracy: 65.2097, 

Epoch : 8
Training loss: 0.9281, Training accuracy: 69.4060, Val loss: 0.9568, Val accuracy: 66.3074, 

Epoch : 9
Training loss: 0.8914, Training accuracy: 70.8976, Val loss: 0.9

In [None]:
def keep_gpu_busy(max_iters=None, device="cuda"):
    """Repeatedly run a trivial tensor op on ``device``.

    Presumably a keep-alive so the hosted runtime sees the accelerator in use
    between experiments -- TODO confirm intent.

    Args:
        max_iters: stop after this many iterations. ``None`` (the default)
            loops forever, like the original cell; interrupt the kernel to stop.
        device: device the tensor is allocated on and the op runs on.

    Returns:
        Number of iterations executed (only reached when ``max_iters`` is set).
    """
    # Allocate once, directly on the target device. The original called
    # ``a.cuda()`` without assigning the result: every iteration made a
    # throw-away device copy while the arithmetic itself ran on the CPU,
    # and the repeated in-place doubling drifted to inf.
    base = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device=device)
    done = 0
    while max_iters is None or done < max_iters:
        _ = base * 2
        done += 1
    return done


# Inside a notebook kernel ``__name__`` is "__main__", so running this cell
# still blocks forever as before; importing this code elsewhere does not.
if __name__ == "__main__":
    keep_gpu_busy()

In [None]:
!pip3 install pyngrok

In [None]:

import os
from getpass import getpass

from pyngrok import ngrok

# Terminate tunnels left open by a previous run of this cell.
ngrok.kill()

# Set the ngrok authtoken (optional) without hard-coding it: a token committed in
# the notebook is readable by anyone with the file. Read it from the
# NGROK_AUTH_TOKEN environment variable, or prompt for it (leave blank to skip).
# The token previously inlined here is leaked and should be revoked and
# regenerated at https://dashboard.ngrok.com/auth
NGROK_AUTH_TOKEN = os.environ.get("NGROK_AUTH_TOKEN") or getpass("ngrok authtoken (blank to skip): ")
if NGROK_AUTH_TOKEN:
    ngrok.set_auth_token(NGROK_AUTH_TOKEN)

# Open an HTTPS tunnel on port 5000 for http://localhost:5000 (the MLflow UI).
ngrok_tunnel = ngrok.connect(addr="5000", proto="http", bind_tls=True)
print("MLflow Tracking UI:", ngrok_tunnel.public_url)



In [None]:
!mlflow ui

[2022-11-14 17:48:07 +0000] [58408] [INFO] Starting gunicorn 20.1.0
[2022-11-14 17:48:07 +0000] [58408] [INFO] Listening at: http://127.0.0.1:5000 (58408)
[2022-11-14 17:48:07 +0000] [58408] [INFO] Using worker: sync
[2022-11-14 17:48:07 +0000] [58411] [INFO] Booting worker with pid: 58411
