<a href="https://colab.research.google.com/github/eisbetterthanpi/vision/blob/main/Meta_Pseudo_Labels_down.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

# Meta Pseudo Labels mar 2021 https://arxiv.org/pdf/2003.10580v4.pdf
# https://github.com/kekmodel/MPL-pytorch


In [None]:
# @title torch augment
# https://github.com/facebookresearch/vicreg/blob/main/augmentations.py
import torch
import torchvision.transforms as transforms

class TrainTransform(object):
    """Strong augmentation used to build the "strong" unlabeled view in MPL.

    Accepts a single image tensor ``(C, H, W)`` or a batch ``(B, C, H, W)``; in
    the batched case every image is augmented independently (its own random
    parameters).

    Args:
        input_normalized: set when incoming tensors were already passed through
            ``Normalize(MEAN, STD)`` -- as the CIFAR-10 datasets in this notebook
            are. The normalization is undone right before the clamp so that the
            clamp to ``[0, 1]`` and the photometric ops (ColorJitter, Grayscale,
            erasing with the mean colour) act on real pixel values; it is
            re-applied once at the end. Previously the normalized tensor was
            clamped to ``[0, 1]`` and then normalized a second time, so the strong
            view had a different distribution from the weak view.
    """

    # ImageNet statistics, matching the dataset-level Normalize used in this notebook.
    MEAN = (0.485, 0.456, 0.406)
    STD = (0.229, 0.224, 0.225)

    @classmethod
    def _denormalize_params(cls):
        """Return ``(mean, std)`` such that ``Normalize(mean, std)`` inverts ``Normalize(MEAN, STD)``.

        ``Normalize`` computes ``(x - mean) / std``; with ``mean = -m / s`` and
        ``std = 1 / s`` it maps ``(p - m) / s`` back to ``p``.
        """
        inv_mean = [-m / s for m, s in zip(cls.MEAN, cls.STD)]
        inv_std = [1.0 / s for s in cls.STD]
        return inv_mean, inv_std

    def __init__(self, input_normalized=True):
        # Geometric ops first; they behave the same in normalized or pixel space
        # (their zero fill is the mean colour while the input is still normalized).
        ops = [
            transforms.RandomPerspective(distortion_scale=0.3, p=0.5),
            transforms.RandomResizedCrop((32, 32), scale=(0.7, 1.0), ratio=(0.8, 1.25),
                                         interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.RandomHorizontalFlip(p=0.5),
        ]
        if input_normalized:
            inv_mean, inv_std = self._denormalize_params()
            ops.append(transforms.Normalize(mean=inv_mean, std=inv_std))
        ops += [
            # clamp, else ColorJitter can return NaN
            # https://discuss.pytorch.org/t/input-is-nan-after-transformation/125455/6
            transforms.Lambda(lambda x: torch.clamp(x, 0., 1.)),
            transforms.RandomApply([transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply([transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0))], p=1.0),
            # erase a small square, filled with the mean colour (0 after the final Normalize)
            transforms.RandomErasing(p=1., scale=(0.1, 0.11), ratio=(1, 1), value=self.MEAN),
            transforms.Normalize(mean=list(self.MEAN), std=list(self.STD)),
        ]
        self.transform = transforms.Compose(ops)

    def __call__(self, sample):
        """Augment ``sample``.

        A 3-D tensor is one image; a 4-D tensor is a batch augmented image by
        image so each image gets its own random parameters.

        Raises:
            ValueError: if ``sample`` is neither 3-D nor 4-D.
        """
        dims = len(sample.shape)
        if dims == 3:
            return self.transform(sample)
        if dims == 4:
            return torch.stack([self.transform(img) for img in sample])
        raise ValueError(f"expected a (C, H, W) or (B, C, H, W) tensor, got shape {tuple(sample.shape)}")

# Module-level strong-augmentation instance; train_loop calls trs(images_uw) to
# derive the strong unlabeled view from each weak unlabeled batch.
trs=TrainTransform()


In [None]:
# @title utils
# https://github.com/kekmodel/MPL-pytorch/blob/main/utils.py
import os
from collections import OrderedDict

import torch
from torch import nn
from torch.nn import functional as F


# def reduce_tensor(tensor, n):
#     rt = tensor.clone()
#     dist.all_reduce(rt, op=dist.ReduceOp.SUM)
#     rt /= n
#     return rt


def create_loss_fn(label_smoothing=0.0, target_device=None):
    """Build the cross-entropy criterion shared by teacher, student and evaluation.

    Args:
        label_smoothing: smoothing factor forwarded to ``nn.CrossEntropyLoss``
            (0 = plain cross entropy; the MPL main args use 0.15).
        target_device: device to move the criterion to; defaults to the
            notebook's global ``device``.

    Returns:
        An ``nn.CrossEntropyLoss`` on ``target_device``.
    """
    # nn.CrossEntropyLoss supports label smoothing natively, so the custom
    # SmoothCrossEntropy* modules are not needed here.
    criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
    if target_device is None:
        target_device = device  # module-level global defined in the "main" cell
    return criterion.to(target_device)


def module_load_state_dict(model, state_dict):
    """Load ``state_dict`` into ``model`` while tolerating a ``module.`` key-prefix mismatch.

    Checkpoints saved through a wrapper such as ``nn.DataParallel`` typically
    prefix every key with ``module.``. First try with that prefix stripped
    (from keys that actually carry it), then with it added to every key.

    Raises:
        RuntimeError: from ``model.load_state_dict`` when neither variant matches.
    """
    prefix = 'module.'
    try:
        # Previously k[7:] chopped 7 characters off every key, prefixed or not.
        model.load_state_dict(OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in state_dict.items()))
    except RuntimeError:  # missing / unexpected keys
        model.load_state_dict(OrderedDict((f'{prefix}{k}', v) for k, v in state_dict.items()))

def model_load_state_dict(model, state_dict):
    """Load ``state_dict`` into ``model``; on a key mismatch retry with ``module.`` prefix fix-ups.

    Only ``RuntimeError`` (what ``load_state_dict`` raises for missing /
    unexpected keys or size mismatches) triggers the fallback. The previous bare
    ``except`` also swallowed KeyboardInterrupt and unrelated bugs.
    """
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        module_load_state_dict(model, state_dict)


import shutil
# './checkpoint/model_best.pth.tar
# def save_checkpoint(state, is_best, finetune=False):
def save_checkpoint(name, save_path, state, is_best, finetune=False):
    """Save ``state`` to ``<save_path>/<name>[_finetune]_last.pth.tar``; copy it to ``..._best.pth.tar`` when ``is_best``.

    Args:
        name: checkpoint base name (e.g. ``'model'``).
        save_path: output directory, created if missing. Previously this argument
            was silently overwritten with ``'./checkpoint'``; that value is still
            used when ``None`` or an empty string is passed.
        state: object handed to ``torch.save`` (here a dict of state_dicts / metrics).
        is_best: also refresh the ``_best`` copy.
        finetune: append ``_finetune`` to the name so fine-tuning does not
            overwrite the main training checkpoints.

    Returns:
        Path of the ``_last`` checkpoint file.
    """
    save_path = save_path or './checkpoint'
    os.makedirs(save_path, exist_ok=True)
    if finetune:
        name = f'{name}_finetune'
    filename = os.path.join(save_path, f'{name}_last.pth.tar')
    torch.save(state, filename, _use_new_zipfile_serialization=False)
    if is_best:
        shutil.copyfile(filename, os.path.join(save_path, f'{name}_best.pth.tar'))
    return filename

def accuracy(output, target, topk=(1,)):
    """Top-k accuracy, in percent, of ``output`` scores against ``target`` class indices.

    Args:
        output: ``(batch, num_classes)`` scores or logits.
        target: ``(batch,)`` integer labels.
        topk: the k values to report.

    Returns:
        A list with one 1-element CPU float tensor per entry of ``topk``.
    """
    cpu = torch.device('cpu')
    scores = output.to(cpu)
    labels = target.to(cpu)
    n_samples = labels.shape[0]
    ranked = scores.sort(dim=1, descending=True)[1]
    # (max_k, batch): row j holds every sample's j-th best guess
    guesses = ranked.narrow(1, 0, max(topk)).t()
    hits = guesses.eq(labels.reshape(1, -1).expand_as(guesses))
    scale = 100.0 / n_samples
    return [hits[:k].reshape(-1).float().sum(dim=0, keepdim=True).mul_(scale) for k in topk]


class SmoothCrossEntropy(nn.Module):
    """Cross entropy against label-smoothed one-hot targets.

    The target puts ``1 - alpha + alpha / K`` on the true class and ``alpha / K``
    on each other class (``K`` = size of the last logits dimension). With
    ``alpha == 0`` this is plain ``F.cross_entropy``.
    """

    def __init__(self, alpha=0.1):
        super().__init__()
        self.alpha = alpha

    def forward(self, logits, labels):
        if self.alpha == 0:
            return F.cross_entropy(logits, labels)
        k = logits.shape[-1]
        smoothed = F.one_hot(labels, num_classes=k).float() * (1. - self.alpha) + self.alpha / k
        per_sample = -(smoothed * torch.log_softmax(logits, dim=-1)).sum(dim=-1)
        return per_sample.mean()


class SmoothCrossEntropyV2(nn.Module):
    """NLL loss with label smoothing.

    loss = mean over the batch of
    ``(1 - eps) * (-log p[true class]) + eps * mean_over_classes(-log p)``;
    ``eps == 0`` falls back to ``F.cross_entropy``.
    """

    def __init__(self, label_smoothing=0.1):
        super().__init__()
        assert label_smoothing < 1.0
        self.smoothing = label_smoothing
        self.confidence = 1. - label_smoothing

    def forward(self, x, target):
        if self.smoothing == 0:
            return F.cross_entropy(x, target)
        log_p = F.log_softmax(x, dim=-1)
        true_class_nll = -log_p.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        uniform_nll = -log_p.mean(dim=-1)
        return (self.confidence * true_class_nll + self.smoothing * uniform_nll).mean()


class AverageMeter(object):
    """Track the most recent value and a count-weighted running mean.

    Adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running sum, count and mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new observations."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


In [None]:
# @title ModelEMA
# exponential moving average: smooths model parameters over training
# https://github.com/kekmodel/MPL-pytorch/blob/main/models.py
import torch
import torch.nn as nn
from copy import deepcopy

class ModelEMA(nn.Module):
    """Exponential moving average (EMA) of a model's weights.

    Holds a ``deepcopy`` of ``model`` in eval mode. Each ``update_parameters``
    call blends parameters as ``ema = decay * ema + (1 - decay) * current`` and
    copies buffers (e.g. BatchNorm running statistics) verbatim.

    Args:
        model: module to track.
        decay: EMA decay factor.
        device: optional device for the shadow copy; incoming tensors are moved
            there before blending.
    """

    def __init__(self, model, decay=0.9999, device=None):
        super().__init__()
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        self.device = device
        if device is not None:
            self.module.to(device=device)

    def forward(self, input):
        return self.module(input)

    def _on_device(self, tensor):
        """Move ``tensor`` to the shadow copy's device, if one was given."""
        return tensor if self.device is None else tensor.to(device=self.device)

    def _update(self, model, update_fn):
        with torch.no_grad():
            for shadow, live in zip(self.module.parameters(), model.parameters()):
                shadow.copy_(update_fn(shadow, self._on_device(live)))
            # buffers are not averaged, just mirrored
            for shadow, live in zip(self.module.buffers(), model.buffers()):
                shadow.copy_(self._on_device(live))

    def update_parameters(self, model):
        decay = self.decay
        self._update(model, update_fn=lambda ema_t, cur_t: decay * ema_t + (1. - decay) * cur_t)

    def state_dict(self):
        return self.module.state_dict()

    def load_state_dict(self, state_dict):
        self.module.load_state_dict(state_dict)



In [None]:
# @title main
# https://github.com/kekmodel/MPL-pytorch/blob/main/main.py

import math
import os
import random
import time

import numpy as np
import torch
from torch.cuda import amp
from torch import nn
from torch.nn import functional as F
from torch import optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from tqdm import tqdm


# temperature = 1 # default 1 / mainargs 0.7
ema = 0.995  # EMA decay for the averaged student; 0 disables it (default 0 / mainargs 0.995)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def set_seed(seed=0):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Args:
        seed: seed value; defaults to 0, the value that used to be hard-coded.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_wait_steps=0, num_cycles=0.5, last_epoch=-1):
    """Return a ``LambdaLR`` implementing a wait / linear warm-up / cosine-decay schedule.

    The LR multiplier at step ``t`` is:
      * ``0`` while ``t < num_wait_steps``;
      * ``t / (num_warmup_steps + num_wait_steps)`` until the warm-up ends (note
        this ramps up from ``num_wait_steps / (wait + warmup)``, not from 0);
      * afterwards ``0.5 * (1 + cos(2 * pi * num_cycles * progress))``, floored at
        0, where ``progress`` goes 0 -> 1 over the remaining training steps.
    """
    ramp_end = num_warmup_steps + num_wait_steps
    decay_steps = float(max(1, num_training_steps - ramp_end))

    def lr_lambda(current_step):
        if current_step < num_wait_steps:
            return 0.0
        if current_step >= ramp_end:
            progress = float(current_step - ramp_end) / decay_steps
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
        return float(current_step) / float(max(1, ramp_end))

    return LambdaLR(optimizer, lr_lambda, last_epoch)

def get_lr(optimizer):
    """Return the current learning rate of ``optimizer``'s first parameter group."""
    groups = optimizer.param_groups
    return groups[0]['lr']


def _next_batch(iterator, loader):
    """Return ``(inputs, targets, iterator)``, restarting ``loader`` once ``iterator`` is exhausted.

    Only ``StopIteration`` triggers the restart; the bare ``except`` used before
    also swallowed real errors and silently retried.
    """
    try:
        inputs, targets = next(iterator)
    except StopIteration:
        iterator = iter(loader)
        inputs, targets = next(iterator)
    return inputs, targets, iterator


def train_loop(labeled_loader, unlabeled_loader, test_loader, finetune_dataset,
               teacher_model, student_model, avg_student_model, criterion,
               t_optimizer, s_optimizer, t_scheduler, s_scheduler, t_scaler, s_scaler,
               eval_step=1000):
    """Run Meta Pseudo Labels training for ``total_steps`` steps, then fine-tune the student.

    Each step:
      1. Teacher forward on [labeled | weak unlabeled | strong unlabeled]; supervised
         loss plus a confidence-masked UDA consistency loss.
      2. Student update on the teacher's hard pseudo labels for the strong view.
      3. Teacher MPL loss, scaled by the change in the student's labeled loss
         before vs. after its update.

    Every ``eval_step`` steps and after the final step, the EMA student (or the
    student when there is no EMA) is evaluated and a checkpoint is written to
    ``./checkpoint``. Afterwards the best checkpoint -- or, if none exists, the
    in-memory weights -- is fine-tuned on ``finetune_dataset``.

    Relies on notebook globals ``total_steps``, ``trs``, ``ema`` and ``device``.

    Args:
        eval_step: evaluation / checkpoint interval in steps (was hard-coded 1000).
    """
    save_path = './checkpoint'
    name = 'model'

    labeled_iter = iter(labeled_loader)
    unlabeled_iter = iter(unlabeled_loader)

    # for author's code formula
    # moving_dot_product = torch.empty(1).to(device)
    # limit = 3.0**(0.5)  # 3 = 6 / (f_in + f_out)
    # nn.init.uniform_(moving_dot_product, -limit, limit)

    # Best test accuracy so far. These were previously read at the first
    # evaluation without ever being assigned (UnboundLocalError).
    best_top1 = 0.
    best_top5 = 0.

    start_step = 0
    for step in range(start_step, total_steps):
        if step % eval_step == 0 or step == start_step:
            # fresh meters for each evaluation window
            batch_time = AverageMeter()
            data_time = AverageMeter()
            s_losses = AverageMeter()
            t_losses = AverageMeter()
            t_losses_l = AverageMeter()
            t_losses_u = AverageMeter()
            t_losses_mpl = AverageMeter()
            mean_mask = AverageMeter()

        teacher_model.train()
        student_model.train()
        end = time.time()

        images_l, targets, labeled_iter = _next_batch(labeled_iter, labeled_loader)
        images_uw, _, unlabeled_iter = _next_batch(unlabeled_iter, unlabeled_loader)
        # The dataset yields a single view; the strong view is derived from the weak one.
        images_us = trs(images_uw)

        data_time.update(time.time() - end)

        images_l = images_l.to(device)
        images_uw = images_uw.to(device)
        images_us = images_us.to(device)
        targets = targets.to(device)
        with amp.autocast():
            batch_size = images_l.shape[0]
            # one teacher forward pass over [labeled | weak unlabeled | strong unlabeled]
            t_images = torch.cat((images_l, images_uw, images_us))
            t_logits = teacher_model(t_images)
            t_logits_l = t_logits[:batch_size]
            # weak and strong batches have the same size, so chunk(2) splits them exactly
            t_logits_uw, t_logits_us = t_logits[batch_size:].chunk(2)
            del t_logits

            t_loss_l = criterion(t_logits_l, targets)

            temperature = 1  # default 1 / mainargs 0.7
            soft_pseudo_label = torch.softmax(t_logits_uw.detach() / temperature, dim=-1)
            max_probs, hard_pseudo_label = torch.max(soft_pseudo_label, dim=-1)

            threshold = 0.95  # default 0.95 / mainargs 0.6
            # only confident weak-view predictions supervise the strong view
            mask = max_probs.ge(threshold).float()
            t_loss_u = torch.mean(-(soft_pseudo_label * torch.log_softmax(t_logits_us, dim=-1)).sum(dim=-1) * mask)
            lambda_u = 8  # default 1 / mainargs 8, coefficient of unlabeled loss
            uda_steps = 10  # default 1 / mainargs 5000, warm-up steps of lambda_u
            weight_u = lambda_u * min(1., (step + 1) / uda_steps)
            t_loss_uda = t_loss_l + weight_u * t_loss_u

            s_images = torch.cat((images_l, images_us))
            s_logits = student_model(s_images)
            s_logits_l = s_logits[:batch_size]
            s_logits_us = s_logits[batch_size:]
            del s_logits

            # student's labeled loss before its update (feedback signal for the teacher)
            s_loss_l_old = F.cross_entropy(s_logits_l.detach(), targets)

            # hard_pseudo_label has one entry per strong-view image (same length as s_logits_us)
            s_loss = criterion(s_logits_us, hard_pseudo_label)

        s_scaler.scale(s_loss).backward()

        # clipping at 1e9 is effectively disabled (grad_clip off)
        s_scaler.unscale_(s_optimizer)
        nn.utils.clip_grad_norm_(student_model.parameters(), 1e9)

        s_scaler.step(s_optimizer)
        s_scaler.update()
        s_scheduler.step()

        if ema > 0:
            avg_student_model.update_parameters(student_model)

        with amp.autocast():
            with torch.no_grad():
                s_logits_l = student_model(images_l)
            s_loss_l_new = F.cross_entropy(s_logits_l.detach(), targets)

            # theoretically correct formula (https://github.com/kekmodel/MPL-pytorch/issues/6)
            dot_product = s_loss_l_old - s_loss_l_new
            # dot_product = s_loss_l_new - s_loss_l_old # author's code formula
            # # moving_dot_product = moving_dot_product * 0.99 + dot_product * 0.01
            # # dot_product = dot_product - moving_dot_product

            _, hard_pseudo_label = torch.max(t_logits_us.detach(), dim=-1)
            t_loss_mpl = dot_product * F.cross_entropy(t_logits_us, hard_pseudo_label)
            t_loss = t_loss_uda + t_loss_mpl

        t_scaler.scale(t_loss).backward()

        t_scaler.unscale_(t_optimizer)
        nn.utils.clip_grad_norm_(teacher_model.parameters(), 1e9)

        t_scaler.step(t_optimizer)
        t_scaler.update()
        t_scheduler.step()

        teacher_model.zero_grad()
        student_model.zero_grad()

        s_losses.update(s_loss.item())
        t_losses.update(t_loss.item())
        t_losses_l.update(t_loss_l.item())
        t_losses_u.update(t_loss_u.item())
        t_losses_mpl.update(t_loss_mpl.item())
        mean_mask.update(mask.mean().item())

        batch_time.update(time.time() - end)

        # Evaluate every eval_step steps AND after the final step. With
        # total_steps < eval_step nothing used to be evaluated or saved, so
        # loading the best checkpoint below raised FileNotFoundError.
        if (step + 1) % eval_step == 0 or step + 1 == total_steps:
            test_model = avg_student_model if avg_student_model is not None else student_model
            test_loss, top1, top5 = evaluate(test_loader, test_model, criterion)

            is_best = top1 > best_top1
            if is_best:
                best_top1 = top1
                best_top5 = top5

            # './checkpoint/model_last.pth.tar' (+ model_best.pth.tar when is_best)
            save_checkpoint(name, save_path, {
                'step': step + 1,
                'teacher_state_dict': teacher_model.state_dict(),
                'student_state_dict': student_model.state_dict(),
                'avg_state_dict': avg_student_model.state_dict() if avg_student_model is not None else None,
                'best_top1': best_top1,
                'best_top5': best_top5,
                'teacher_optimizer': t_optimizer.state_dict(),
                'student_optimizer': s_optimizer.state_dict(),
                'teacher_scheduler': t_scheduler.state_dict(),
                'student_scheduler': s_scheduler.state_dict(),
                'teacher_scaler': t_scaler.state_dict(),
                'student_scaler': s_scaler.state_dict(),
            }, is_best)

    # Fine-tune starting from the best checkpoint. map_location follows `device`
    # (the hard-coded 'cuda:0' failed on CPU-only machines).
    ckpt_name = f'{save_path}/{name}_best.pth.tar'
    if os.path.isfile(ckpt_name):
        checkpoint = torch.load(ckpt_name, map_location=device)
        if checkpoint['avg_state_dict'] is not None:
            model_load_state_dict(student_model, checkpoint['avg_state_dict'])
        else:
            model_load_state_dict(student_model, checkpoint['student_state_dict'])
    elif avg_student_model is not None:
        # No "best" file (accuracy never beat 0): mirror the checkpoint preference
        # for EMA weights using the in-memory averaged student.
        model_load_state_dict(student_model, avg_student_model.state_dict())
    finetune(finetune_dataset, test_loader, student_model, criterion)
    return


def evaluate(test_loader, model, criterion):
    """Evaluate ``model`` on ``test_loader`` without gradients and under autocast.

    Leaves ``model`` in eval mode.

    Returns:
        ``(mean loss, top-1 accuracy %, top-5 accuracy %)``, each averaged over
        samples (every batch weighted by its size).
    """
    loss_meter, top1_meter, top5_meter = AverageMeter(), AverageMeter(), AverageMeter()
    model.eval()
    progress = tqdm(test_loader, disable=False)
    with torch.no_grad():
        for images, targets in progress:
            n = images.shape[0]
            images = images.to(device)
            targets = targets.to(device)
            with amp.autocast():
                outputs = model(images)
                loss = criterion(outputs, targets)

            acc1, acc5 = accuracy(outputs, targets, (1, 5))
            loss_meter.update(loss.item(), n)
            top1_meter.update(acc1[0], n)
            top5_meter.update(acc5[0], n)

        progress.close()
        return loss_meter.avg, top1_meter.avg, top5_meter.avg


def finetune(finetune_dataset, test_loader, model, criterion, epochs=1):
    """Supervised fine-tuning of ``model`` on ``finetune_dataset``.

    Runs ``epochs`` epochs of Nesterov SGD (lr 3e-5, momentum 0.9) with AMP. After
    every epoch it evaluates on ``test_loader`` and writes
    ``./checkpoint/model_finetune_last.pth.tar`` (plus ``..._best`` on improvement).

    Fixes over the previous version:
      * ``best_top1``/``best_top5`` were read before assignment;
      * ``save_checkpoint`` was called without ``name``/``save_path`` (TypeError);
      * the ``RandomSampler`` was created but unused, so batches were never shuffled;
      * the progress bar always claimed 625 epochs;
      * ``batch_time`` was never reset per batch.

    Args:
        epochs: number of fine-tuning epochs (the commented upstream value is 625).

    Returns:
        ``(best_top1, best_top5)`` over the fine-tuning epochs.
    """
    # NOTE(review): presumably mirrors upstream, where the model exposes a `drop`
    # layer to disable; on a torchvision ResNet this only attaches an unused Identity.
    model.drop = nn.Identity()
    labeled_loader = DataLoader(finetune_dataset, sampler=RandomSampler(finetune_dataset),
                                batch_size=512, num_workers=4, pin_memory=True)
    optimizer = optim.SGD(model.parameters(), lr=3e-5, momentum=0.9, weight_decay=0, nesterov=True)
    scaler = amp.GradScaler()
    best_top1 = 0.
    best_top5 = 0.
    global_step = 0

    for epoch in range(epochs):
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        model.train()
        end = time.time()
        labeled_iter = tqdm(labeled_loader, disable=False)
        for images, targets in labeled_iter:
            data_time.update(time.time() - end)
            batch_size = images.shape[0]
            images = images.to(device)
            targets = targets.to(device)
            model.zero_grad()
            with amp.autocast():
                outputs = model(images)
                loss = criterion(outputs, targets)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            global_step += 1

            losses.update(loss.item(), batch_size)
            batch_time.update(time.time() - end)
            end = time.time()
            labeled_iter.set_description(
                f"Finetune Epoch: {epoch+1:2}/{epochs:2}. Data: {data_time.avg:.2f}s. "
                f"Batch: {batch_time.avg:.2f}s. Loss: {losses.avg:.4f}. ")
        labeled_iter.close()

        test_loss, top1, top5 = evaluate(test_loader, model, criterion)

        is_best = top1 > best_top1
        if is_best:
            best_top1 = top1
            best_top5 = top5

        # same name / directory as train_loop, suffixed with `_finetune`
        save_checkpoint('model', './checkpoint', {
            'step': global_step,
            'best_top1': best_top1,
            'best_top5': best_top5,
            'student_state_dict': model.state_dict(),
            'avg_state_dict': None,
            'student_optimizer': optimizer.state_dict(),
        }, is_best, finetune=True)
    return best_top1, best_top5



# labeled_dataset, unlabeled_dataset, test_dataset, finetune_dataset = DATASET_GETTERS[dataset](args)

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# set_seed() was defined but never called: seed before the random labeled /
# unlabeled split so every run (including a resumed one) uses the same labeled
# subset, and so the model heads built later are initialised reproducibly.
set_seed()

# Normalization happens here, at dataset level, so train_loop receives
# normalized tensors (including the unlabeled batch it passes to trs).
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
train_data = datasets.CIFAR10(root="data", train=True, download=True, transform=transform)
# 10% labeled / 90% unlabeled
labeled_dataset, unlabeled_dataset = torch.utils.data.random_split(train_data, [.1, .9])
test_dataset = datasets.CIFAR10(root="data", train=False, download=True, transform=transform)
finetune_dataset = labeled_dataset  # fine-tuning reuses the labeled split


batch_size = 64  # default 64 / mainargs 128
train_sampler = RandomSampler
labeled_loader = DataLoader(labeled_dataset, sampler=train_sampler(labeled_dataset), batch_size=batch_size, num_workers=4, drop_last=True)
# mu = 7: coefficient of unlabeled batch size
unlabeled_loader = DataLoader(unlabeled_dataset, sampler=train_sampler(unlabeled_dataset), batch_size=batch_size * 7, num_workers=4, drop_last=True)
test_loader = DataLoader(test_dataset, sampler=SequentialSampler(test_dataset), batch_size=batch_size, num_workers=4)

num_classes = 10
# if dataset == "cifar10": depth, widen_factor = 28, 2
# elif dataset == 'cifar100': depth, widen_factor = 28, 8
# teacher_model = WideResNet(num_classes=num_classes, depth=depth, widen_factor=widen_factor, dropout=0, dense_dropout=0.2)
# student_model = WideResNet(num_classes=num_classes, depth=depth, widen_factor=widen_factor, dropout=0, dense_dropout=0.2)


from torchvision import models
def get_resnet():
    """Build an ImageNet-pretrained ResNet-152 with a fresh ``num_classes``-way head, wrapped in ``torch.compile``.

    The head now returns raw logits. It used to end in ``nn.Softmax``, but every
    consumer (``nn.CrossEntropyLoss``, ``F.cross_entropy``, and the teacher's
    ``torch.softmax(t_logits_uw / temperature)``) applies its own (log-)softmax.
    The double softmax pinned the losses near ln(10) ~= 2.303, and it capped the
    teacher's max probability at e / (e + 9) ~= 0.23, so the 0.95 confidence mask
    was always 0.

    The head stays wrapped in ``nn.Sequential`` so its parameter key
    (``fc.0.weight``) matches checkpoints written by the previous version.

    NOTE(review): the ImageNet stem (7x7 stride-2 conv + max-pool) shrinks 32x32
    CIFAR inputs very aggressively; a CIFAR-style stem may be worth considering.
    """
    model = models.resnet152(weights='DEFAULT')  # 18 34 50 101 152
    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(  # og (fc): Linear(in_features=2048, out_features=1000, bias=True)
        nn.Linear(num_ftrs, num_classes, bias=False),
    )
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = torch.compile(model.to(device))
    return model

teacher_model = get_resnet()
student_model = get_resnet()

# Averaged (EMA) copy of the student, evaluated instead of the raw student when enabled.
avg_student_model = None
if ema > 0: avg_student_model = ModelEMA(student_model, ema)


criterion = create_loss_fn()

# Parameters whose name contains one of these substrings get no weight decay.
# NOTE(review): torchvision ResNets also have BatchNorm layers named
# `downsample.1` (no 'bn' in the name), which therefore still get weight decay;
# confirm whether that is intended.
no_decay = ['bn']
weight_decay = 5e-4  # default 0 / mainargs 5e-4
teacher_parameters = [{'params': [p for n, p in teacher_model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},
    {'params': [p for n, p in teacher_model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]
student_parameters = [{'params': [p for n, p in student_model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},
    {'params': [p for n, p in student_model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]

# lr default 0.01/ mainargs 0.05
t_optimizer = optim.SGD(teacher_parameters, lr=0.05, momentum=0.9, nesterov=True)
s_optimizer = optim.SGD(student_parameters, lr=0.05, momentum=0.9, nesterov=True)

total_steps = 30  # 300000
# Warm-up / wait lengths must fit inside total_steps. With the previous fixed
# 100 / 100 and total_steps=30, the student's LR stayed exactly 0 for the whole
# run (logged "LR: 0.0000") and the teacher never left warm-up. The cap keeps
# 100 for long runs.
warmup_steps = min(100, total_steps // 10)  # default 0 / mainargs 5000
t_scheduler = get_cosine_schedule_with_warmup(t_optimizer, warmup_steps, total_steps)

student_wait_steps = min(100, total_steps // 10)  # default 0 / mainargs 3000
s_scheduler = get_cosine_schedule_with_warmup(s_optimizer, warmup_steps, total_steps, student_wait_steps)

t_scaler = amp.GradScaler()
s_scaler = amp.GradScaler()

# # optionally resume from a checkpoint
# if resume:
#     if os.path.isfile(resume):
#         logger.info(f"=> loading checkpoint '{resume}'")
#         loc = f'cuda:{gpu}'
#         checkpoint = torch.load(resume, map_location=loc)
#         best_top1 = checkpoint['best_top1'].to(torch.device('cpu'))
#         best_top5 = checkpoint['best_top5'].to(torch.device('cpu'))
#         if not (evaluate or finetune):
#             start_step = checkpoint['step']
#             t_optimizer.load_state_dict(checkpoint['teacher_optimizer'])
#             s_optimizer.load_state_dict(checkpoint['student_optimizer'])
#             t_scheduler.load_state_dict(checkpoint['teacher_scheduler'])
#             s_scheduler.load_state_dict(checkpoint['student_scheduler'])
#             t_scaler.load_state_dict(checkpoint['teacher_scaler'])
#             s_scaler.load_state_dict(checkpoint['student_scaler'])
#             model_load_state_dict(teacher_model, checkpoint['teacher_state_dict'])
#             if avg_student_model is not None:
#                 model_load_state_dict(avg_student_model, checkpoint['avg_state_dict'])
#         else:
#             if checkpoint['avg_state_dict'] is not None:
#                 model_load_state_dict(student_model, checkpoint['avg_state_dict'])
#             else:
#                 model_load_state_dict(student_model, checkpoint['student_state_dict'])

    #     logger.info(f"=> loaded checkpoint '{resume}' (step {checkpoint['step']})")
    # else:
    #     logger.info(f"=> no checkpoint found at '{resume}'")


# finetune(finetune_dataset, test_loader, student_model, criterion)
# evaluate(test_loader, student_model, criterion)

# Start from clean gradients, then run MPL training followed by fine-tuning.
for net in (teacher_model, student_model):
    net.zero_grad()
train_loop(
    labeled_loader=labeled_loader,
    unlabeled_loader=unlabeled_loader,
    test_loader=test_loader,
    finetune_dataset=finetune_dataset,
    teacher_model=teacher_model,
    student_model=student_model,
    avg_student_model=avg_student_model,
    criterion=criterion,
    t_optimizer=t_optimizer,
    s_optimizer=s_optimizer,
    t_scheduler=t_scheduler,
    s_scheduler=s_scheduler,
    t_scaler=t_scaler,
    s_scaler=s_scaler,
)





Files already downloaded and verified
Files already downloaded and verified


  0%|          | 0/1000 [00:00<?, ?it/s]

torch.Size([64, 3, 32, 32]) torch.Size([448, 3, 32, 32]) torch.Size([448, 3, 32, 32])


Train Iter: 300/300. LR: 0.0000. Data: 1.48s. Batch: 3.07s. S_Loss: 2.2101. T_Loss: 2.0056. Mask: 0.0000. :  30%|███       | 300/1000 [41:06<1:35:55,  8.22s/it]


torch.Size([64, 3, 32, 32]) torch.Size([448, 3, 32, 32]) torch.Size([448, 3, 32, 32])


Train Iter:   2/ 30. LR: 0.0000. Data: 1.87s. Batch: 119.26s. S_Loss: 2.3031. T_Loss: 2.3113. Mask: 0.0000. :   0%|          | 2/1000 [03:58<27:20:34, 98.63s/it] 

torch.Size([64, 3, 32, 32]) torch.Size([448, 3, 32, 32]) torch.Size([448, 3, 32, 32])


Train Iter:   3/ 30. LR: 0.0000. Data: 1.89s. Batch: 80.49s. S_Loss: 2.3026. T_Loss: 2.3008. Mask: 0.0000. :   0%|          | 3/1000 [04:01<15:12:59, 54.94s/it]

torch.Size([64, 3, 32, 32]) torch.Size([448, 3, 32, 32]) torch.Size([448, 3, 32, 32])


Train Iter:   4/ 30. LR: 0.0000. Data: 1.76s. Batch: 60.93s. S_Loss: 2.3027. T_Loss: 2.3057. Mask: 0.0000. :   0%|          | 4/1000 [04:03<9:26:45, 34.14s/it] 

torch.Size([64, 3, 32, 32]) torch.Size([448, 3, 32, 32]) torch.Size([448, 3, 32, 32])


Train Iter:   5/ 30. LR: 0.0000. Data: 1.67s. Batch: 49.18s. S_Loss: 2.3024. T_Loss: 2.3064. Mask: 0.0000. :   0%|          | 5/1000 [04:05<6:15:13, 22.63s/it]

torch.Size([64, 3, 32, 32]) torch.Size([448, 3, 32, 32]) torch.Size([448, 3, 32, 32])


Train Iter:   6/ 30. LR: 0.0000. Data: 1.61s. Batch: 41.35s. S_Loss: 2.3025. T_Loss: 2.3082. Mask: 0.0000. :   1%|          | 6/1000 [04:08<4:19:49, 15.68s/it]

torch.Size([64, 3, 32, 32]) torch.Size([448, 3, 32, 32]) torch.Size([448, 3, 32, 32])


Train Iter:   7/ 30. LR: 0.0000. Data: 1.56s. Batch: 35.76s. S_Loss: 2.3028. T_Loss: 2.3069. Mask: 0.0000. :   1%|          | 7/1000 [04:10<3:06:34, 11.27s/it]

torch.Size([64, 3, 32, 32]) torch.Size([448, 3, 32, 32]) torch.Size([448, 3, 32, 32])


Train Iter:   8/ 30. LR: 0.0000. Data: 1.53s. Batch: 31.57s. S_Loss: 2.3026. T_Loss: 2.3057. Mask: 0.0000. :   1%|          | 8/1000 [04:12<2:18:52,  8.40s/it]

torch.Size([64, 3, 32, 32]) torch.Size([448, 3, 32, 32]) torch.Size([448, 3, 32, 32])


Train Iter:   9/ 30. LR: 0.0000. Data: 1.58s. Batch: 28.39s. S_Loss: 2.3027. T_Loss: 2.2997. Mask: 0.0000. :   1%|          | 9/1000 [04:15<1:50:46,  6.71s/it]

torch.Size([64, 3, 32, 32]) torch.Size([448, 3, 32, 32]) torch.Size([448, 3, 32, 32])


Train Iter:  10/ 30. LR: 0.0000. Data: 1.56s. Batch: 25.77s. S_Loss: 2.3025. T_Loss: 2.3019. Mask: 0.0000. :   1%|          | 10/1000 [04:17<1:27:51,  5.33s/it]

torch.Size([64, 3, 32, 32]) torch.Size([448, 3, 32, 32]) torch.Size([448, 3, 32, 32])


Train Iter:  11/ 30. LR: 0.0000. Data: 1.53s. Batch: 23.63s. S_Loss: 2.3024. T_Loss: 2.3018. Mask: 0.0000. :   1%|          | 11/1000 [04:20<1:12:06,  4.37s/it]

torch.Size([64, 3, 32, 32]) torch.Size([448, 3, 32, 32]) torch.Size([448, 3, 32, 32])


Train Iter:  12/ 30. LR: 0.0000. Data: 1.51s. Batch: 21.84s. S_Loss: 2.3027. T_Loss: 2.3012. Mask: 0.0000. :   1%|          | 12/1000 [04:22<1:00:57,  3.70s/it]

torch.Size([64, 3, 32, 32]) torch.Size([448, 3, 32, 32]) torch.Size([448, 3, 32, 32])


Train Iter:  13/ 30. LR: 0.0000. Data: 1.49s. Batch: 20.33s. S_Loss: 2.3027. T_Loss: 2.3008. Mask: 0.0000. :   1%|▏         | 13/1000 [04:24<53:14,  3.24s/it]  

torch.Size([64, 3, 32, 32]) torch.Size([448, 3, 32, 32]) torch.Size([448, 3, 32, 32])


Train Iter:  14/ 30. LR: 0.0000. Data: 1.50s. Batch: 19.05s. S_Loss: 2.3030. T_Loss: 2.3021. Mask: 0.0000. :   1%|▏         | 14/1000 [04:26<49:27,  3.01s/it]

torch.Size([64, 3, 32, 32]) torch.Size([448, 3, 32, 32]) torch.Size([448, 3, 32, 32])


Train Iter:  15/ 30. LR: 0.0000. Data: 1.52s. Batch: 17.97s. S_Loss: 2.3030. T_Loss: 2.3014. Mask: 0.0000. :   2%|▏         | 15/1000 [04:29<48:28,  2.95s/it]

torch.Size([64, 3, 32, 32]) torch.Size([448, 3, 32, 32]) torch.Size([448, 3, 32, 32])


Train Iter:  16/ 30. LR: 0.0000. Data: 1.51s. Batch: 16.98s. S_Loss: 2.3029. T_Loss: 2.3015. Mask: 0.0000. :   2%|▏         | 16/1000 [04:31<44:43,  2.73s/it]

torch.Size([64, 3, 32, 32]) torch.Size([448, 3, 32, 32]) torch.Size([448, 3, 32, 32])


Train Iter:  17/ 30. LR: 0.0000. Data: 1.49s. Batch: 16.11s. S_Loss: 2.3030. T_Loss: 2.3031. Mask: 0.0000. :   2%|▏         | 17/1000 [04:34<42:06,  2.57s/it]

torch.Size([64, 3, 32, 32]) torch.Size([448, 3, 32, 32]) torch.Size([448, 3, 32, 32])


Train Iter:  18/ 30. LR: 0.0000. Data: 1.48s. Batch: 15.34s. S_Loss: 2.3031. T_Loss: 2.3046. Mask: 0.0000. :   2%|▏         | 18/1000 [04:36<40:05,  2.45s/it]

torch.Size([64, 3, 32, 32]) torch.Size([448, 3, 32, 32]) torch.Size([448, 3, 32, 32])


Train Iter:  19/ 30. LR: 0.0000. Data: 1.47s. Batch: 14.64s. S_Loss: 2.3031. T_Loss: 2.3029. Mask: 0.0000. :   2%|▏         | 19/1000 [04:38<38:47,  2.37s/it]

torch.Size([64, 3, 32, 32]) torch.Size([448, 3, 32, 32]) torch.Size([448, 3, 32, 32])


Train Iter:  20/ 30. LR: 0.0000. Data: 1.49s. Batch: 14.06s. S_Loss: 2.3030. T_Loss: 2.3028. Mask: 0.0000. :   2%|▏         | 20/1000 [04:41<41:28,  2.54s/it]

torch.Size([64, 3, 32, 32]) torch.Size([448, 3, 32, 32]) torch.Size([448, 3, 32, 32])


Train Iter:  21/ 30. LR: 0.0000. Data: 1.49s. Batch: 13.50s. S_Loss: 2.3029. T_Loss: 2.3011. Mask: 0.0000. :   2%|▏         | 21/1000 [04:43<40:29,  2.48s/it]

torch.Size([64, 3, 32, 32]) torch.Size([448, 3, 32, 32]) torch.Size([448, 3, 32, 32])


Train Iter:  22/ 30. LR: 0.0000. Data: 1.48s. Batch: 12.98s. S_Loss: 2.3029. T_Loss: 2.2993. Mask: 0.0000. :   2%|▏         | 22/1000 [04:45<38:52,  2.39s/it]

torch.Size([64, 3, 32, 32]) torch.Size([448, 3, 32, 32]) torch.Size([448, 3, 32, 32])


Train Iter:  23/ 30. LR: 0.0000. Data: 1.47s. Batch: 12.52s. S_Loss: 2.3029. T_Loss: 2.2987. Mask: 0.0000. :   2%|▏         | 23/1000 [04:48<37:58,  2.33s/it]

torch.Size([64, 3, 32, 32]) torch.Size([448, 3, 32, 32]) torch.Size([448, 3, 32, 32])


Train Iter:  24/ 30. LR: 0.0000. Data: 1.47s. Batch: 12.09s. S_Loss: 2.3028. T_Loss: 2.2988. Mask: 0.0000. :   2%|▏         | 24/1000 [04:50<37:17,  2.29s/it]

torch.Size([64, 3, 32, 32]) torch.Size([448, 3, 32, 32]) torch.Size([448, 3, 32, 32])


Train Iter:  25/ 30. LR: 0.0000. Data: 1.46s. Batch: 11.69s. S_Loss: 2.3028. T_Loss: 2.2985. Mask: 0.0000. :   2%|▎         | 25/1000 [04:52<37:03,  2.28s/it]

torch.Size([64, 3, 32, 32]) torch.Size([448, 3, 32, 32]) torch.Size([448, 3, 32, 32])


Train Iter:  26/ 30. LR: 0.0000. Data: 1.48s. Batch: 11.35s. S_Loss: 2.3027. T_Loss: 2.2991. Mask: 0.0000. :   3%|▎         | 26/1000 [04:55<40:16,  2.48s/it]

torch.Size([64, 3, 32, 32]) torch.Size([448, 3, 32, 32]) torch.Size([448, 3, 32, 32])


Train Iter:  27/ 30. LR: 0.0000. Data: 1.47s. Batch: 11.01s. S_Loss: 2.3028. T_Loss: 2.2993. Mask: 0.0000. :   3%|▎         | 27/1000 [04:57<38:48,  2.39s/it]

torch.Size([64, 3, 32, 32]) torch.Size([448, 3, 32, 32]) torch.Size([448, 3, 32, 32])


Train Iter:  28/ 30. LR: 0.0000. Data: 1.46s. Batch: 10.70s. S_Loss: 2.3027. T_Loss: 2.2985. Mask: 0.0000. :   3%|▎         | 28/1000 [04:59<37:48,  2.33s/it]

torch.Size([64, 3, 32, 32]) torch.Size([448, 3, 32, 32]) torch.Size([448, 3, 32, 32])


Train Iter:  29/ 30. LR: 0.0000. Data: 1.46s. Batch: 10.41s. S_Loss: 2.3028. T_Loss: 2.2973. Mask: 0.0000. :   3%|▎         | 29/1000 [05:02<36:59,  2.29s/it]

torch.Size([64, 3, 32, 32]) torch.Size([448, 3, 32, 32]) torch.Size([448, 3, 32, 32])


Train Iter:  30/ 30. LR: 0.0000. Data: 1.45s. Batch: 10.13s. S_Loss: 2.3029. T_Loss: 2.2960. Mask: 0.0000. :   3%|▎         | 30/1000 [05:04<36:30,  2.26s/it]

FileNotFoundError: ignored