<a href="https://colab.research.google.com/github/eisbetterthanpi/vision/blob/main/Meta_Pseudo_Labels.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Meta Pseudo Labels mar 2021 https://arxiv.org/pdf/2003.10580v4.pdf
# https://github.com/kekmodel/MPL-pytorch
'''
Meta Pseudo Labels by Google Research, implementation by kekmodel,
modified to save a significant amount of memory.

Other methods such as ordinal regression, ensembling and test-time augmentation can only give a slight increase in accuracy.
This is the most promising method, with the added benefit of producing a more robust classifier.

It will require a huge unlabelled dataset of houses and a very clean dataset of labelled houses.

Can't seem to make it train properly.

'''

In [None]:
# @title download
# # # google images unlabeled
# !gdown 1ncx2DJ-GXqrQd6nL5UEmj6GLT4w-9qYs -O house.zip
# !unzip /content/house.zip -d /
# !rm -R /content/house/.ipynb_checkpoints

# # # 70k+gmap
# !gdown 1-CZp7TbhJLeRQpbKQCyT8ofGg89Yt137 -O gsv.zip
# !unzip /content/gsv.zip -d /
# !rm -R /content/gsv70kg/.ipynb_checkpoints
# !rm -R /content/gsv70kg/01/.ipynb_checkpoints
# !rm -R /content/gsv70kg/02/.ipynb_checkpoints
# !rm -R /content/gsv70kg/03/.ipynb_checkpoints
# !rm -R /content/gsv70kg/04/.ipynb_checkpoints
# !rm -R /content/gsv70kg/05/.ipynb_checkpoints
# !rm -R /content/gsv70kg/06/.ipynb_checkpoints


In [None]:
# @title data
import torch
import numpy as np
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torchvision import datasets, transforms
# Dataset-level preprocessing: PIL image -> float tensor, then per-channel
# normalisation with the mean/std constants below.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

# Root of the labelled images; ImageFolder derives the class from each sub-folder name.
labeled_dir='/content/gsv70kg'

labeled_data = datasets.ImageFolder(labeled_dir, transform=transform)
torch.manual_seed(0)

# Seed the RNGs before the random split below so the train/test partition is repeatable.
# import random
# random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)


# 90% / 10% train/test split of the labelled data (fractional lengths).
train_data, test_data = torch.utils.data.random_split(labeled_data, [.9,.1])
# Uncomment to shrink both splits to 1% for quick smoke tests.
# train_data, _ = torch.utils.data.random_split(train_data, [.01,.99])
# test_data, _ = torch.utils.data.random_split(test_data, [.01,.99])
finetune_dataset = train_data

# Flat directory of unlabelled images (read by Datasetme, not ImageFolder).
unlabel_dir='/content/house'
# unlabel_data = datasets.ImageFolder(unlabel_dir, transform=transform)
# unlabel_data = datasets.ImageFolder(unlabel_dir, transform=transform)

import os
from PIL import Image
class Datasetme(torch.utils.data.Dataset):
    """Unlabelled image dataset: every regular file directly inside ``dir`` is one sample.

    Args:
        dir: Directory holding the image files (it is not searched recursively).
        transform: Optional callable applied to each RGB ``PIL.Image`` before it is returned.

    Attributes:
        data: Sorted list of the file names found in ``dir``.
    """

    def __init__(self, dir, transform=None):
        self.dir = dir
        # os.listdir returns entries in arbitrary order, which would defeat the RNG
        # seeding done elsewhere, so sort. Skip anything that is not a regular file
        # (e.g. a stray `.ipynb_checkpoints` folder), since it cannot be opened as an image.
        self.data = sorted(
            name for name in os.listdir(dir)
            if os.path.isfile(os.path.join(dir, name))
        )
        self.transform = transform

    def __getitem__(self, index):
        """Return the image at ``index`` as RGB, transformed when a transform was given."""
        img_file = os.path.join(self.dir, self.data[index])
        # The context manager closes the file handle immediately; convert() returns an
        # independent, fully loaded copy that stays valid afterwards.
        with Image.open(img_file) as raw:
            image = raw.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image

    def __len__(self):
        """Number of image files found."""
        return len(self.data)

unlabel_data = Datasetme(unlabel_dir, transform=transform)
# unlabel_data, _ = torch.utils.data.random_split(unlabel_data, [.01,.99])


batch_size = 64 # 16 is max for res152; default 64/ mainargs128
grad_acc = 1 # number of chunks each loaded batch is split into for gradient accumulation

# res152 batch16 gradacc4

num_batches=int(np.ceil(len(train_data)/batch_size))

# Plain (unshuffled) loader over the training split.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, pin_memory=True)

train_sampler = RandomSampler

# Labelled batches must come from the training split only: `labeled_data` is the
# full ImageFolder and therefore still contains every sample of `test_data`,
# so sampling from it would leak the held-out set into training.
labeled_loader = DataLoader(train_data, sampler=train_sampler(train_data), batch_size=batch_size*grad_acc, drop_last=True)
mu=7 # coefficient of unlabeled batch size
unlabeled_loader = DataLoader(unlabel_data, sampler=train_sampler(unlabel_data), batch_size=batch_size*grad_acc * mu, drop_last=True)
# Held-out split, iterated in a fixed order.
test_loader = DataLoader(test_data, sampler=SequentialSampler(test_data), batch_size=batch_size*grad_acc)

# The loaders keep their own references; drop the module-level names.
del labeled_data, train_data, test_data, unlabel_data


In [None]:
# @title torch augment
# https://github.com/facebookresearch/vicreg/blob/main/augmentations.py
import torch
import torchvision.transforms as transforms

class TrainTransform(object):
    """Random augmentation pipeline for image tensors, usable on one image or a batch.

    The pipeline is: random perspective -> random resized crop to 400x640 ->
    horizontal flip -> clamp to [0, 1] -> colour jitter -> random grayscale ->
    Gaussian blur -> random erasing.

    NOTE(review): the tensors reaching this transform were already normalised at
    dataset level, so the clamp to [0, 1] discards everything outside that range
    and the erase value is expressed in un-normalised units — confirm this is intended.
    """

    def __init__(self):
        self.transform = transforms.Compose([
                transforms.RandomPerspective(distortion_scale=0.3, p=0.5), # me
                transforms.RandomResizedCrop((400,640), scale=(0.7, 1.0), ratio=(0.8, 1.25), interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.Lambda(lambda x : torch.clamp(x, 0., 1.)), # clamp else ColorJitter will return nan https://discuss.pytorch.org/t/input-is-nan-after-transformation/125455/6
                transforms.RandomApply([transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)], p=0.8,),
                transforms.RandomGrayscale(p=0.2),
                transforms.RandomApply([transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0)),], p=1.0),
                transforms.RandomErasing(p=1., scale=(0.1, 0.11), ratio=(1,1), value=(0.485, 0.456, 0.406)),
                # ToTensor / Normalize are applied at dataset level, not here.
                ])

    def __call__(self, sample):
        """Augment ``sample``.

        Args:
            sample: A 3-D tensor (one image) or a 4-D tensor (a batch of images).

        Returns:
            The augmented tensor. For a batch, each image gets its own
            independently drawn random parameters and the results are re-stacked.

        Raises:
            ValueError: If ``sample`` is neither 3-D nor 4-D.
        """
        dims = len(sample.shape)
        if dims == 3:
            return self.transform(sample) # single image: one set of random parameters
        if dims == 4:
            # Transform image by image so every element of the batch is augmented differently.
            return torch.stack([self.transform(img) for img in sample])
        raise ValueError(f"TrainTransform expects a 3-D or 4-D tensor, got {dims} dimensions")

# Shared augmentation instance; the training loops call it on image batches.
trs=TrainTransform()


In [None]:
# @title utils
# https://github.com/kekmodel/MPL-pytorch/blob/main/utils.py
import torch
from torch import nn
from torch.nn import functional as F

def create_loss_fn(label_smoothing=0.0):
    """Build the classification criterion and move it to the global ``device``.

    Args:
        label_smoothing: Smoothing factor forwarded to ``nn.CrossEntropyLoss``.
            Defaults to 0 (no smoothing), the value that used to be hard-coded
            here; the original note lists 0.15 as the reference ("mainargs") setting.

    Returns:
        An ``nn.CrossEntropyLoss`` module placed on ``device``.
    """
    criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
    return criterion.to(device)

from collections import OrderedDict
def module_load_state_dict(model, state_dict):
    """Load ``state_dict`` into ``model`` across a ``module.`` key-prefix mismatch.

    First tries the keys with a leading ``module.`` removed; if the model rejects
    those, retries with ``module.`` prepended to every original key.

    Args:
        model: The ``nn.Module`` to load into.
        state_dict: Mapping of parameter/buffer names to tensors.

    Raises:
        RuntimeError: If neither key variant matches the model.
    """
    prefix = 'module.'
    # Only strip keys that really carry the prefix; slicing the first seven
    # characters unconditionally would mangle keys that never had it.
    stripped = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )
    try:
        model.load_state_dict(stripped)
    except RuntimeError:  # raised by load_state_dict on missing/unexpected keys
        prefixed = OrderedDict((f'{prefix}{k}', v) for k, v in state_dict.items())
        model.load_state_dict(prefixed)

def model_load_state_dict(model, state_dict):
    """Load ``state_dict`` into ``model``, tolerating a ``module.`` prefix mismatch.

    Tries a direct load first and falls back to ``module_load_state_dict`` only
    when the model rejects the keys, so unrelated errors are not swallowed.
    """
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:  # key/shape mismatch reported by load_state_dict
        module_load_state_dict(model, state_dict)


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each k in ``topk``.

    Args:
        output: Score tensor of shape (batch, classes).
        target: Tensor of class indices, one per row of ``output``.
        topk: Iterable of k values to evaluate.

    Returns:
        A list with one single-element tensor per requested k.
    """
    cpu = torch.device('cpu')
    scores, labels = output.to(cpu), target.to(cpu)
    n_samples = labels.shape[0]
    deepest = max(topk)
    # Class indices ordered from highest to lowest score, one row per sample.
    ranking = scores.sort(dim=1, descending=True)[1]
    # Shape (deepest, batch): row r holds every sample's (r+1)-th best guess.
    top_preds = ranking.narrow(1, 0, deepest).t()
    hits = top_preds.eq(labels.reshape(1, -1).expand_as(top_preds))
    return [
        hits[:k].reshape(-1).float().sum(dim=0, keepdim=True).mul_(100.0 / n_samples)
        for k in topk
    ]


class SmoothCrossEntropy(nn.Module):
    """Cross-entropy against label-smoothed one-hot targets.

    Each target distribution is ``one_hot * (1 - alpha) + alpha / num_classes``.
    With ``alpha == 0`` this is the ordinary ``F.cross_entropy``.
    """

    def __init__(self, alpha=0.1):
        super().__init__()
        self.alpha = alpha

    def forward(self, logits, labels):
        """Return the mean smoothed cross-entropy of ``logits`` against ``labels``."""
        if self.alpha == 0:
            return F.cross_entropy(logits, labels)
        n_cls = logits.shape[-1]
        one_hot = F.one_hot(labels, num_classes=n_cls).float()
        target_probs = one_hot * (1. - self.alpha) + self.alpha / n_cls
        log_probs = torch.log_softmax(logits, dim=-1)
        per_sample = -(target_probs * log_probs).sum(dim=-1)
        return per_sample.mean()

class SmoothCrossEntropyV2(nn.Module):
    """NLL loss with label smoothing.

    Blends the negative log-likelihood of the true class (weight ``1 - label_smoothing``)
    with the mean negative log-probability over all classes (weight ``label_smoothing``).
    """

    def __init__(self, label_smoothing=0.1):
        super().__init__()
        assert label_smoothing < 1.0
        self.smoothing = label_smoothing
        self.confidence = 1. - label_smoothing

    def forward(self, x, target):
        """Return the mean smoothed loss of logits ``x`` against class indices ``target``."""
        if self.smoothing == 0:
            return F.cross_entropy(x, target)
        log_p = F.log_softmax(x, dim=-1)
        # Negative log-probability assigned to the true class of each sample.
        true_class_nll = (-log_p.gather(dim=-1, index=target.unsqueeze(1))).squeeze(1)
        # Uniform-target term: average negative log-probability over all classes.
        uniform_nll = -log_p.mean(dim=-1)
        blended = self.confidence * true_class_nll + self.smoothing * uniform_nll
        return blended.mean()


# from main
import math
from torch.optim.lr_scheduler import LambdaLR
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_wait_steps=0, num_cycles=0.5, last_epoch=-1):
    """Build a ``LambdaLR`` with a wait phase, a linear ramp and a cosine decay.

    The learning-rate multiplier is:
      * 0 while ``step < num_wait_steps``;
      * ``step / max(1, num_warmup_steps + num_wait_steps)`` until the ramp ends;
      * ``max(0, 0.5 * (1 + cos(pi * 2 * num_cycles * progress)))`` afterwards,
        where ``progress`` runs over the remaining training steps.
    """
    ramp_end = num_warmup_steps + num_wait_steps
    ramp_denominator = float(max(1, ramp_end))
    decay_span = float(max(1, num_training_steps - num_warmup_steps - num_wait_steps))

    def lr_lambda(current_step):
        if current_step < num_wait_steps:
            return 0.0
        if current_step < ramp_end:
            return float(current_step) / ramp_denominator
        progress = float(current_step - num_warmup_steps - num_wait_steps) / decay_span
        cosine = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
        return max(0.0, 0.5 * (1.0 + cosine))

    return LambdaLR(optimizer, lr_lambda, last_epoch)



In [None]:
# @title plot lr scheduler
# import matplotlib
# import matplotlib.pyplot as plt
# # matplotlib.rcParams['figure.dpi'] = 300
# # plt.axis('off')

# from torch import optim
# t_optimizer = optim.SGD(teacher_parameters, lr=1.5e-4, momentum=0.9, nesterov=True)
# s_optimizer = optim.SGD(student_parameters, lr=3e-4, momentum=0.9, nesterov=True)

# total_steps=1000 # 300000
# warmup_steps = 100 # default 0 / mainargs 5000
# # t_scheduler = get_cosine_schedule_with_warmup(t_optimizer, warmup_steps, total_steps)
# t_scheduler = get_cosine_schedule_with_warmup(t_optimizer, warmup_steps, total_steps,10)
# student_wait_steps = 50 # default 0 / mainargs 3000
# s_scheduler = get_cosine_schedule_with_warmup(s_optimizer, warmup_steps, total_steps, student_wait_steps)


# tlr_lst=[]
# slr_lst=[]
# for t in range(total_steps):
#     tlr=t_optimizer.param_groups[0]["lr"]
#     tlr_lst.append(tlr)
#     slr=s_optimizer.param_groups[0]["lr"]
#     slr_lst.append(slr)
#     t_scheduler.step()
#     s_scheduler.step()
# plt.plot(tlr_lst)
# plt.plot(slr_lst)

# plt.show()


In [None]:
# @title ModelEMA
# exponential moving average, smoothen model parameters
# https://github.com/kekmodel/MPL-pytorch/blob/main/models.py
import torch
import torch.nn as nn
from copy import deepcopy

class ModelEMA(nn.Module):
    """Exponential-moving-average shadow of a model.

    Holds a deep copy of ``model`` in eval mode. ``update_parameters`` blends the
    live parameters into the copy with weight ``1 - decay`` and copies the
    buffers over verbatim.

    Args:
        model: The model to shadow.
        decay: Weight kept from the previous averaged value on each update.
        device: Optional device for the shadow copy; live tensors are moved to
            it before being blended.
    """

    def __init__(self, model, decay=0.9999, device=None):
        super().__init__()
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        self.device = device
        if device is not None:
            self.module.to(device=device)

    def forward(self, input):
        """Run the averaged model."""
        return self.module(input)

    def _update(self, model, update_fn):
        """Apply ``update_fn(avg, live)`` to every parameter pair and copy all buffers."""
        relocate = self.device is not None
        with torch.no_grad():
            for avg_p, live_p in zip(self.module.parameters(), model.parameters()):
                source = live_p.to(device=self.device) if relocate else live_p
                avg_p.copy_(update_fn(avg_p, source))
            # Buffers are taken from the live model as-is, not averaged.
            for avg_b, live_b in zip(self.module.buffers(), model.buffers()):
                source = live_b.to(device=self.device) if relocate else live_b
                avg_b.copy_(source)

    def update_parameters(self, model):
        """Fold the current parameters of ``model`` into the moving average."""
        def blend(e, m):
            return self.decay * e + (1. - self.decay) * m

        self._update(model, update_fn=blend)

    def state_dict(self):
        """State dict of the averaged model (keys match the wrapped model's own)."""
        return self.module.state_dict()

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` straight into the averaged model."""
        self.module.load_state_dict(state_dict)



In [None]:
# @title big teacher
import torch
import torch.nn as nn
from torchvision import models
# !gdown 1ysJfdsvwMiWbCdkvFHwNqAUnJTtm6KbT -O res152adamw71.pth # ty
# !gdown 1VaPxGoaLjmt7K9VHi0FWbJ5efEZTLhwd -O res18teacher.pth # A
# !pip install bitsandbytes

# Frozen "big teacher" backbone: a ResNet whose classification head is replaced
# by an empty nn.Sequential, so `model(x)` returns the `num_ftrs`-dimensional
# features that feed the small trainable teacher head.
model = models.resnet18(weights='DEFAULT') # 18 34 50 101 152
num_ftrs = model.fc.in_features
model.fc = nn.Sequential( # og (fc): Linear(in_features=2048, out_features=1000, bias=True)
    # nn.Linear(num_ftrs, num_classes, bias=False),
    # nn.Softmax(dim=1),
    )

# model.load_state_dict(torch.load('/content/bigTeacher.pth'))
# _, modelsd, _,_ = torch.load('/content/bigTeacher.pth').values()
# _, modelsd, _,_ = torch.load('/content/res152adamw71.pth').values()
# NOTE(review): this unpacks the checkpoint dict by position and assumes the model
# weights are its second value — verify against the code that saved the file.
_, modelsd, _,_ = torch.load('/content/res18teacher.pth').values()
# strict=False: keys that do not match this head-less model are ignored without error.
model.load_state_dict(modelsd, strict=False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# model = torch.compile(model.to(device)) # compiling teacher leads to significantly higher vram usage


# Eval mode: the backbone is only called under torch.no_grad() in the training loop.
model.eval()
print('uh')


# torch.save(model.state_dict(), '/content/model.pth')
# modelsd = torch.load('/content/model.pth')
# model.load_state_dict(modelsd, strict=False)
# # model = torch.compile(model.to(device))


uh


In [None]:
# @title model

# Number of output classes; read by get_resnet() when it builds the head.
num_classes = 6
# if dataset == "cifar10": depth, widen_factor = 28, 2
# elif dataset == 'cifar100': depth, widen_factor = 28, 8
# teacher_model = WideResNet(num_classes=num_classes, depth=depth, widen_factor=widen_factor, dropout=0, dense_dropout=0.2)
# student_model = WideResNet(num_classes=num_classes, depth=depth, widen_factor=widen_factor, dropout=0, dense_dropout=0.2)
import torch
import torch.nn as nn
from torchvision import models
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def get_resnet():
    """Build a compiled ResNet-18 classifier that outputs raw logits.

    The pretrained ``fc`` layer is replaced by a bias-free linear layer with
    ``num_classes`` outputs. No softmax is appended: every loss applied to this
    model's output (``nn.CrossEntropyLoss`` / ``F.cross_entropy``) performs
    log-softmax itself, so emitting probabilities would apply softmax twice and
    flatten the gradients.

    Returns:
        The model, moved to ``device`` and wrapped by ``torch.compile``.
    """
    model = models.resnet18(weights='DEFAULT') # 18 34 50 101 152
    num_ftrs = model.fc.in_features
    # Keep the head inside nn.Sequential so checkpoint keys stay `fc.0.weight`.
    model.fc = nn.Sequential( # og (fc): Linear(in_features=2048, out_features=1000, bias=True)
        nn.Linear(num_ftrs, num_classes, bias=False),
        )
    model = torch.compile(model.to(device))
    return model

class Small(nn.Module):
    """Small MLP head mapping backbone features to class logits.

    Args:
        embed_dim: Size of the incoming feature vector.
        output_dim: Number of classes.

    The network returns raw logits. The training loop feeds its output to
    ``CrossEntropyLoss``, ``torch.softmax(... / temperature)`` and
    ``log_softmax``, all of which expect logits. With a trailing softmax the
    re-softmaxed confidence for 6 classes at temperature 0.7 can never exceed
    about 0.455, so a 0.6 confidence threshold would mask every pseudo-label out.
    """

    def __init__(self, embed_dim, output_dim):
        super(Small, self).__init__()
        hidden_size=512
        self.lin = nn.Sequential(
            nn.Linear(embed_dim, hidden_size), nn.ReLU(),
            nn.Linear(hidden_size, hidden_size), nn.ReLU(),
            nn.Linear(hidden_size, hidden_size), nn.ReLU(),
            nn.Linear(hidden_size, hidden_size), nn.ReLU(),
            # Final projection; no softmax, so the output stays in logit space.
            nn.Linear(hidden_size, output_dim, bias=False),
        )

    def forward(self, x):
        """Return logits of shape (batch, output_dim) for features ``x``."""
        logits = self.lin(x)
        return logits




# @title ensemble
import torch
import torch.nn as nn

class Ensemble(nn.Module):
    """Residual MLP head: input projection, three residual blocks, output projection.

    Args:
        embed_dim: Size of the incoming feature vector.
        output_dim: Size of the output vector (e.g. 6 classes).
    """

    def __init__(self, embed_dim, output_dim):
        super().__init__()
        self.output_dim = output_dim # 6
        self.embed_dim = embed_dim
        h_dim = 512
        # Assemble the layers in forward order, then wrap them in one Sequential.
        layers = [nn.Linear(self.embed_dim, h_dim), nn.ReLU()]
        layers.extend(Block(h_dim, h_dim) for _ in range(3))
        layers.append(nn.Linear(h_dim, self.output_dim))
        self.fwd = nn.Sequential(*layers)

    def forward(self, x):
        """Return the un-normalised outputs for features ``x``."""
        return self.fwd(x)

class Block(nn.Module):
    """Residual block computing ``x + ReLU(Linear(x))``.

    When ``drop`` is truthy, batch-norm and dropout (probability ``drop``) are
    applied before the linear layer.
    """

    def __init__(self, in_dim, out_dim, drop=None):
        super().__init__()
        # Optional normalisation + dropout prefix, built first to keep construction order.
        prefix = [nn.BatchNorm1d(in_dim), nn.Dropout(drop)] if drop else []
        self.fwd = nn.Sequential(*prefix, nn.Linear(in_dim, out_dim), nn.ReLU(),)

    def forward(self, x):
        """Add the block's transformation of ``x`` onto ``x`` itself."""
        residual = self.fwd(x)
        return x + residual

# teacher_model = Ensemble(2048, 6).to(device)
# teacher_model = torch.compile(Ensemble(2048, 6).to(device))




# teacher_model = Small(num_ftrs,6).to(device)
# Trainable teacher: a small MLP head over the backbone features.
# `num_ftrs` is not defined in this cell; it is the module-level value set when
# the backbone was built.
teacher_model = torch.compile(Small(num_ftrs,6).to(device))
# teacher_model = get_resnet()






# Student: a full ResNet classifier (already moved to `device` and compiled by get_resnet()).
student_model = get_resnet()
# _, modelsd, _,_ = torch.load('/content/res18teacher.pth').values()
# student_model.load_state_dict(modelsd, strict=False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
student_model = student_model.to(device)
# student_model = torch.compile(student_model.to(device)) #


# Optional moving-average copy of the student; only created when `ema` > 0.
avg_student_model = None
ema = 0.995 # default 0 / mainargs 0.995

if ema > 0: avg_student_model = ModelEMA(student_model, ema)

# Two optimizer parameter groups per model: parameters whose name contains any
# entry of `no_decay` get weight_decay 0, all others get `weight_decay`.
no_decay = ['bn']
weight_decay = 5e-4 # default 0 / mainargs 5e-4
teacher_parameters = [{'params': [p for n, p in teacher_model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},
    {'params': [p for n, p in teacher_model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]
student_parameters = [{'params': [p for n, p in student_model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},
    {'params': [p for n, p in student_model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]

# teacher_model.zero_grad()
# student_model.zero_grad()





In [None]:
# @title try
# import torch
# import torch.nn as nn
# from torchvision import models
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# # # 1.9, 1.7, 1.2
# im = torch.rand(32,3,400,680,device=device) #
# # print(32*3*400*680*32/8)
# # print(im.element_size() * im.nelement())

# def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad)
# # print(count_parameters(model))


# model = models.resnet152(weights='DEFAULT') # 18 34 50 101 152
# model.fc = nn.Sequential()
# model = model.to(device)
# # model = torch.compile(model.to(device))
# # 32, 50:3.3-5=1.7
# # print(count_parameters(model))

# # amp res152 student cant train on batch 32, compile or not

# torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cudnn.allow_tf32 = True
# model.eval()
# student_model.eval()
# with torch.cuda.amp.autocast():
#     with torch.no_grad(): #
#         # out = model(im) # 1:2.7, 4:6.3, 8:12.8
#         out = student_model(im) #
# # 16*15: 11.6
# 32:1.3-8.7=6.4

# student_model.train()
# out = student_model(im) #

# model.train()
# with torch.cuda.amp.autocast():
#     out = model(im) #


# # 152:3.2
# # im = torch.rand(4,3,400,680,device=device) #
# # out = model(im) # 5.9

# im = torch.rand(16,2048,device=device) #
# out = teacher_model(im) # 5.9

# torch.cuda.empty_cache()


In [None]:
# @title wandb
# https://docs.wandb.ai/quickstart
!pip install wandb
import wandb
wandb.login() # NOTE(review): what appeared to be an API key was pasted in this comment; it has been redacted here — rotate it and supply the key at the login prompt or through the environment instead of the notebook source.
# Start a run in the "mpl" project; `config` only records descriptive metadata.
run = wandb.init(
    project="mpl",
    config={
        "model": "scratch 18",
        # "optim": "adamw",
        "optim": "sgd",
        # "lr": lr,
        # "epochs": epochs,
    })


In [None]:
# #title train teacher first?

In [None]:
# @title mpl grad acc
# https://github.com/kekmodel/MPL-pytorch/blob/main/main.py
import torch
from torch.cuda import amp
from torch import nn
from torch.nn import functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Allow TF32 in CUDA matmuls and cuDNN kernels.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Separate mixed-precision gradient scalers for the teacher and the student;
# train() reads both as module-level globals.
t_scaler = amp.GradScaler()
s_scaler = amp.GradScaler()
def train(labeled_loader, unlabeled_loader, teacher_model, student_model,
        avg_student_model, criterion, t_optimizer, s_optimizer, t_scheduler=None, s_scheduler=None):
    """Run one pass of Meta Pseudo Labels training over ``unlabeled_loader``.

    Each step:
      1. Teacher supervised loss on the labelled batch (backward immediately).
      2. Student loss on the labelled batch *before* its update (no grad).
      3. For each of the ``mu`` unlabelled chunks: the teacher's UDA consistency
         loss between the weak and the augmented view, and the student loss on
         the teacher's hard pseudo-labels.
      4. Student optimizer step, then EMA update when ``ema > 0``.
      5. Student loss on the labelled batch *after* the update; the difference
         ``old - new`` scales the teacher's MPL loss on its own pseudo-labels.
      6. Teacher optimizer step.

    Args:
        labeled_loader: Yields ``(images, targets)``; restarted when exhausted.
        unlabeled_loader: Yields image batches; its length sets the step count.
        teacher_model: Trainable teacher head fed with backbone features.
        student_model: Student classifier fed with images.
        avg_student_model: EMA wrapper updated after every student step when ``ema > 0``.
        criterion: Classification loss for teacher and student.
        t_optimizer, s_optimizer: Optimizers of teacher and student.
        t_scheduler, s_scheduler: Optional LR schedulers stepped once per step.

    Relies on module-level globals: ``model`` (frozen backbone), ``trs``,
    ``grad_acc``, ``mu``, ``device``, ``ema``, ``t_scaler``, ``s_scaler``, ``wandb``.
    """
    labeled_iter = iter(labeled_loader)
    unlabeled_iter = iter(unlabeled_loader)
    size = len(unlabeled_loader)

    # for author's code formula
    # moving_dot_product = torch.empty(1).to(device)
    # limit = 3.0**(0.5)  # 3 = 6 / (f_in + f_out)
    # nn.init.uniform_(moving_dot_product, -limit, limit)

    teacher_model.train()
    student_model.train()
    for step in range(len(unlabeled_loader)):

        # Restart a loader only when it is actually exhausted; any other error
        # (e.g. an unreadable image) must surface instead of being retried silently.
        try: cimages_l, ctargets = next(labeled_iter)
        except StopIteration:
            labeled_iter = iter(labeled_loader)
            cimages_l, ctargets = next(labeled_iter)
        try: cimages_uw = next(unlabeled_iter)
        except StopIteration:
            unlabeled_iter = iter(unlabeled_loader)
            cimages_uw = next(unlabeled_iter)
        cimages_l, ctargets = cimages_l.to(device), ctargets.to(device)
        cimages_uw = cimages_uw.to(device)
        cimages_us = trs(cimages_uw) # augmented view of the unlabelled batch

        t_loss_l = 0
        t_loss_wu = 0
        s_loss = 0
        s_loss_l_old = 0
        # Teacher logits on the augmented view, kept (with graph) for the MPL loss below.
        ct_logits_us = torch.empty(0, device=device)
        for images_l, targets, images_uw, images_us in zip(cimages_l.chunk(grad_acc), ctargets.chunk(grad_acc), cimages_uw.chunk(grad_acc), cimages_us.chunk(grad_acc)): # for grad acc 1/2
            with amp.autocast():
                with torch.no_grad(): m_logits_l = model(images_l) # frozen backbone features
                t_logits_l = teacher_model(m_logits_l)
                t_l_l = criterion(t_logits_l, targets)
            t_scaler.scale(t_l_l).backward() # backward right away to free the graph
            t_loss_l += t_l_l

            # Student loss on labelled data before its update (no gradient needed).
            with amp.autocast():
                with torch.no_grad(): s_logits_l = student_model(images_l)
                s_l_l_old = F.cross_entropy(s_logits_l, targets)
                s_loss_l_old += s_l_l_old

            for i_us, i_uw in zip(images_us.chunk(mu), images_uw.chunk(mu)):
                with amp.autocast():
                    with torch.no_grad():
                        m_i_us = model(i_us) # frozen backbone features
                        m_i_uw = model(i_uw) # frozen backbone features
                        t_l_uw = teacher_model(m_i_uw) # weak-view logits need no grad
                    t_l_us = teacher_model(m_i_us)
                    ct_logits_us = torch.cat((ct_logits_us, t_l_us))

                    temperature = 0.7 # default 1 / mainargs 0.7
                    soft_pseudo_label = torch.softmax(t_l_uw / temperature, dim=-1)
                    max_probs, hard_pseudo_label = torch.max(soft_pseudo_label, dim=-1) # all no grads

                    threshold = 0.6 # default 0.95 / mainargs 0.6
                    # Only pseudo-labels at least `threshold` confident contribute to the UDA loss.
                    mask = max_probs.ge(threshold).float()
                    t_loss_u = torch.mean(-(soft_pseudo_label * torch.log_softmax(t_l_us, dim=-1)).sum(dim=-1) * mask)
                    lambda_u = 8 # default 1 / mainargs 8 coefficient of unlabeled loss
                    uda_steps = 10 # default 1 / mainargs 5000 warmup steps of lambda-u
                    weight_u = lambda_u * min(1., (step + 1) / uda_steps) # >0

                    i_us.requires_grad=True
                    s_l_us = student_model(i_us)
                    s_l = criterion(s_l_us, hard_pseudo_label)
                s_scaler.scale(s_l).backward()
                s_loss += s_l
                # retain_graph: the graph of t_l_us is needed again for the MPL backward.
                t_scaler.scale(weight_u * t_loss_u).backward(retain_graph=True)
                t_loss_wu += weight_u * t_loss_u

        t_loss_uda = t_loss_l + t_loss_wu
        print('t_loss_l',t_loss_l.item())
        print('t_loss_wu',weight_u, t_loss_u.item())

        # ---- student update ----
        s_scaler.unscale_(s_optimizer)
        nn.utils.clip_grad_norm_(student_model.parameters(), 1e9)
        s_scaler.step(s_optimizer)
        s_scaler.update()
        if s_scheduler: s_scheduler.step()
        student_model.zero_grad()
        if ema > 0: avg_student_model.update_parameters(student_model)


        # Student loss on the same labelled batch after the update.
        s_loss_l_new = 0
        for images_l, targets in zip(cimages_l.chunk(grad_acc), ctargets.chunk(grad_acc)): # for grad acc 2/2
            with amp.autocast():
                with torch.no_grad(): s_logits_l = student_model(images_l)
                s_l_l_new = F.cross_entropy(s_logits_l, targets)
                s_loss_l_new += s_l_l_new
        # Positive when the pseudo-label step lowered the student's labelled loss.
        dot_product = s_loss_l_old - s_loss_l_new
        # # moving_dot_product = moving_dot_product * 0.99 + dot_product * 0.01
        # # dot_product = dot_product - moving_dot_product

        t_loss_mpl = 0
        for t_logits_us in ct_logits_us.chunk(grad_acc): # for grad acc 2/2
            for t_l_us in t_logits_us.chunk(mu):
                with amp.autocast():
                    _, hard_pseudo_label = torch.max(t_l_us.detach(), dim=-1)
                    t_l_mpl = dot_product * F.cross_entropy(t_l_us, hard_pseudo_label) # dot_product no grad
                t_scaler.scale(t_l_mpl).backward(retain_graph=True)
                t_loss_mpl += t_l_mpl

        t_loss = t_loss_uda + t_loss_mpl
        print('t_loss_mpl',t_loss_mpl.item())
        if step%10==0: print(step,"/",size," ", "t_loss: ", t_loss.item(), "s_loss: ", s_loss.item())
        # Best-effort logging: keep training when wandb is unavailable, but do not
        # swallow KeyboardInterrupt/SystemExit.
        try: wandb.log({"t_loss": t_loss.item(), "s_loss": s_loss.item()})
        except Exception: pass


        # ---- teacher update ----
        t_scaler.unscale_(t_optimizer) # if grad_clip > 0:
        nn.utils.clip_grad_norm_(teacher_model.parameters(), 1e9)
        t_scaler.step(t_optimizer)
        t_scaler.update()
        if t_scheduler: t_scheduler.step()
        teacher_model.zero_grad()

    return



In [None]:
# @title kaggle mpl
# https://www.kaggle.com/code/hengck23/playground-for-meta-pseudo-label
# https://www.kaggle.com/code/conjuring92/nbme-meta-pseudo-labels

# ------- Training Loop  ------------------------------------------#
# Reference Meta Pseudo Labels loop pasted from the Kaggle notebooks linked above.
# NOTE(review): it uses many names that are not defined anywhere in the visible
# part of this file (num_steps, config, tqdm, AverageMeter, teacher, student,
# accelerator, train_dl, valid_dl, scorer_fn, chain, get_lr, save_checkpoint, ...),
# so it looks like reading material rather than a runnable cell — confirm.
for step in range(num_steps):

    #------ Reset buffers After Validation ------------------------#
    if step % config["validation_interval"] == 0:
        progress_bar = tqdm(range(min(config["validation_interval"], num_steps)))
        s_loss_meter = AverageMeter()
        t_loss_meter = AverageMeter()

    teacher.train()
    student.train()

    t_optimizer.zero_grad()
    s_optimizer.zero_grad()

    #------ Get Train & Unlabelled Batch -------------------------#
    # NOTE(review): Python 3 iterators expose __next__, not .next(); if these are
    # plain DataLoader iterators the call always raises and the fallback builds a
    # fresh iterator on every step — verify against the original notebook.
    try: train_b = train_iter.next()
    except Exception as e: train_b = next(train_dl.__iter__())
    try: unlabelled_b = unlabelled_iter.next()
    except: unlabelled_b = next(unlabelled_dl.__iter__())

    #------- Meta Training Steps ---------------------------------#
    # logits of the current (pre-update) student on the labelled batch
    s_logits_train_b = student.get_logits(train_b)

    # loss of the current student on the labelled batch (logits detached: no gradient)
    train_b_labels = train_b["labels"]
    train_b_masks = train_b_labels.gt(-0.5)
    s_loss_train_b = student.compute_loss(logits=s_logits_train_b.detach(), labels=train_b_labels, masks=train_b_masks,)

    # get teacher generated pseudo labels for unlabelled data
    unlabelled_b_masks = unlabelled_b["label_mask"].eq(1).unsqueeze(-1)
    t_logits_unlabelled_b = teacher.get_logits(unlabelled_b)
    pseudo_y_unlabelled_b = (t_logits_unlabelled_b.detach() > 0).float()  # hard pseudo label

    #------ Train Student: With Pseudo Label Data ------------------#
    s_logits_unlabelled_b = student.get_logits(unlabelled_b)
    s_loss_unlabelled_b = student.compute_loss(logits=s_logits_unlabelled_b, labels=pseudo_y_unlabelled_b, masks=unlabelled_b_masks)

    # backpropagation of student loss on unlabelled data
    accelerator.backward(s_loss_unlabelled_b)
    s_optimizer.step()  # update student params
    s_scheduler.step()

    #------ Train Teacher ------------------------------------------#
    # student loss on the same labelled batch after its update
    s_logits_train_b_new = student.get_logits(train_b)
    s_loss_train_b_new = student.compute_loss(logits=s_logits_train_b_new.detach(), labels=train_b_labels, masks=train_b_masks,)
    # NOTE(review): this is new - old, the opposite sign of the `old - new`
    # difference used by train() in this file — confirm which convention is intended.
    change = s_loss_train_b_new - s_loss_train_b  # performance improvement from student

    t_logits_train_b = teacher.get_logits(train_b)
    t_loss_train_b = teacher.compute_loss(logits=t_logits_train_b, labels=train_b_labels, masks=train_b_masks)

    t_loss_mpl = change * F.binary_cross_entropy_with_logits(t_logits_unlabelled_b, pseudo_y_unlabelled_b, reduction='none')  # mpl loss
    t_loss_mpl = torch.masked_select(t_loss_mpl, unlabelled_b_masks).mean()
    t_loss = t_loss_train_b + t_loss_mpl

    # backpropagation of teacher's loss
    accelerator.backward(t_loss)
    t_optimizer.step()
    t_scheduler.step()

    #------ Progress Bar Updates ----------------------------------#
    s_loss_meter.update(s_loss_train_b_new.item())
    t_loss_meter.update(t_loss.item())

    progress_bar.set_description(f"STEP: {step+1:5}/{num_steps:5}. "f"LR: {get_lr(s_optimizer):.4f}. "f"TL: {t_loss_meter.avg:.4f}. "f"SL: {s_loss_meter.avg:.4f}. ")
    progress_bar.update()

    #------ Evaluation & Checkpointing -----------------------------#
    if (step + 1) % config["validation_interval"] == 0:
        progress_bar.close()

        #----- Teacher Evaluation  ---------------------------------#
        teacher.eval()
        teacher_preds = []
        with torch.no_grad():
            for batch in valid_dl:
                p = teacher.get_logits(batch)
                teacher_preds.append(p)
        # sigmoid -> numpy, keep index 0 of the last axis, then flatten the per-batch lists
        teacher_preds = [torch.sigmoid(p).detach().cpu().numpy()[:, :, 0] for p in teacher_preds]
        teacher_preds = list(chain(*teacher_preds))
        teacher_lb = scorer_fn(teacher_preds)
        print(f"After step {step+1} Teache LB: {teacher_lb}")

        # build the teacher checkpoint state (the save call itself is commented out below)
        accelerator.wait_for_everyone()
        teacher = accelerator.unwrap_model(teacher)
        teacher_state = {'step': step + 1, 'state_dict': teacher.state_dict(), 'optimizer': t_optimizer.state_dict(), 'lb': teacher_lb}
        is_best = False
        if teacher_lb > best_teacher_score:
            best_teacher_score = teacher_lb
            is_best = True
        # save_checkpoint(config, teacher_state, is_teacher=True, is_best=is_best)

        #----- Student Evaluation  ---------------------------------#
        student.eval()
        student_preds = []
        with torch.no_grad():
            for batch in valid_dl:
                p = student.get_logits(batch)
                student_preds.append(p)
        student_preds = [torch.sigmoid(p).detach().cpu().numpy()[:, :, 0] for p in student_preds]
        student_preds = list(chain(*student_preds))
        student_lb = scorer_fn(student_preds)
        print(f"After step {step+1} Student LB: {student_lb}")

        # save student
        accelerator.wait_for_everyone()
        student = accelerator.unwrap_model(student)

        student_state = {'step': step + 1, 'state_dict': student.state_dict(), 'optimizer': s_optimizer.state_dict(), 'lb': student_lb}
        is_best = False
        if student_lb > best_student_score:
            best_student_score = student_lb
            is_best = True
        save_checkpoint(config, student_state, is_teacher=False, is_best=is_best)



In [None]:
# @title strain eval

def evaluate(test_loader, model, criterion, verbose=True):
    """Score ``model`` on ``test_loader`` and report accuracy and loss.

    Args:
        test_loader: iterable yielding ``(images, targets)`` batches; must
            expose ``.dataset`` so the sample count can be read with ``len``.
        model: callable whose output is compared with ``targets`` via
            ``outputs.argmax(1)``.
        criterion: loss callable invoked as ``criterion(outputs, targets)``.
        verbose: when true, print the accuracy and loss.

    Returns:
        ``(correct, test_loss)``: the fraction of correctly classified
        samples, and the sum of per-batch losses divided by the dataset size.
        Both are ``0.0`` when the dataset is empty.
    """
    size = len(test_loader.dataset)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for images, targets in test_loader:
            images = images.to(device)
            targets = targets.to(device)
            with amp.autocast():
                outputs = model(images)
                loss = criterion(outputs, targets)
            test_loss += loss.item()
            correct += (outputs.argmax(1) == targets).type(torch.float).sum().item()
    # NOTE(review): the per-batch losses are summed and divided by the number
    # of samples. If `criterion` already returns a batch mean, this reports
    # roughly (mean loss / batch size) -- confirm the intended normalisation.
    denom = max(size, 1)  # avoid ZeroDivisionError on an empty dataset
    test_loss /= denom
    correct /= denom
    # Best-effort logging: wandb may be missing or not initialised. Only
    # ordinary exceptions are swallowed so Ctrl-C still interrupts the run.
    try:
        wandb.log({"test loss": test_loss})
    except Exception:
        pass
    if verbose:
        print(f"Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f}")
    return correct, test_loss


# @title train test function
# Allow TF32 for CUDA matmuls and for cuDNN.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Module-level gradient scaler shared by every call to strain() below.
scaler = torch.cuda.amp.GradScaler()
# Checkpointing helpers; reference implementation:
# https://github.com/prigoyal/pytorch_memonger/blob/master/models/optimized/resnet_new.py
# (imported here, not used within this cell)
from torch.utils.checkpoint import checkpoint, checkpoint_sequential

trs=TrainTransform() # for image augmentation during train time
# train function with automatic mixed precision
def strain(dataloader, model, loss_fn, optimizer, scheduler=None, verbose=True, grad_acc_steps=4):
    """Train ``model`` for one pass over ``dataloader`` with mixed precision.

    Gradients are accumulated over ``grad_acc_steps`` batches before each
    optimizer step. The last batch always triggers a step, so no gradients
    are left pending at the end of the pass.

    Args:
        dataloader: iterable yielding ``(x, y)`` batches; must support
            ``len()`` and expose ``.dataset``.
        model: module being trained.
        loss_fn: loss callable invoked as ``loss_fn(pred, y)``.
        optimizer: optimizer stepped through the module-level ``scaler``.
        scheduler: optional LR scheduler, stepped once per optimizer step.
        verbose: when true, print the loss after every batch.
        grad_acc_steps: number of batches to accumulate per optimizer step
            (default 4, the previously hard-coded value).

    Returns:
        List with one entry per batch: ``loss.item() / len(y)``.

    Raises:
        ValueError: if ``grad_acc_steps`` is smaller than 1.
    """
    if grad_acc_steps < 1:
        raise ValueError(f"grad_acc_steps must be >= 1, got {grad_acc_steps}")
    size = len(dataloader.dataset)
    num_batches = len(dataloader)  # loop-invariant, hoisted
    model.train()
    # Drop gradients left over from earlier training so they do not leak into
    # the first accumulation window.
    optimizer.zero_grad()
    loss_list = []
    for batch, (x, y) in enumerate(dataloader):
        x, y = x.to(device), y.to(device)
        with torch.cuda.amp.autocast(): # automatic mixed precision
            x = trs(x) # image augmentation during train time to use gpu
            pred = model(x)
            loss = loss_fn(pred, y)
        scaler.scale(loss).backward()
        # gradient accumulation: step every grad_acc_steps batches and on the last batch
        if (batch + 1) % grad_acc_steps == 0 or batch + 1 == num_batches:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            if scheduler is not None:
                scheduler.step()

        train_loss = loss.item()/len(y)
        loss_list.append(train_loss)
        # Best-effort logging: wandb may be missing or not initialised. Only
        # ordinary exceptions are swallowed so Ctrl-C still interrupts the run.
        try:
            wandb.log({"train loss": train_loss})
        except Exception:
            pass
        current = batch * len(x)
        if verbose:
            print(f"loss: {train_loss:>7f}  [{current:>5d}/{size:>5d}]")
    return loss_list



In [None]:
# @title save

# from google.colab import drive
# drive.mount('/content/drive')


# checkpoint = {
# 'epoch': t+1,
# 'teacher_model': teacher_model.state_dict(),
# 'student_model': student_model.state_dict(),
# 'avg_student_model': avg_student_model.state_dict(),
# 't_optimizer': t_optimizer.state_dict(),
# 's_optimizer': s_optimizer.state_dict(),
# 't_scheduler': t_scheduler.state_dict(),
# 's_scheduler': s_scheduler.state_dict(),}
# # torch.save(checkpoint, pth)
# torch.save(checkpoint, 'ckpt.pth')


In [None]:
# @title optimizers and schedulers

# Loss used by the MPL training loop, the fine-tuning cell and evaluate().
criterion = create_loss_fn()

from torch import optim
# lr default 0.01/ mainargs 0.05
# og:t0.05s0.05 , psl:t1e-3s3e-4, rpsl:t
# t_optimizer = optim.SGD(teacher_parameters, lr=0.05, momentum=0.9, nesterov=True)
# s_optimizer = optim.SGD(student_parameters, lr=0.05, momentum=0.9, nesterov=True)
# t_optimizer = optim.SGD(teacher_parameters, lr=3e-4, momentum=0.9, nesterov=True)
# Separate Nesterov-SGD optimizers for the teacher and the student.
t_optimizer = optim.SGD(teacher_parameters, lr=1e-4, momentum=0.9, nesterov=True)
s_optimizer = optim.SGD(student_parameters, lr=3e-4, momentum=0.9, nesterov=True)
# s_optimizer = optim.SGD(student_parameters, lr=1e-3, momentum=0.9, nesterov=True)
# optimizer = bnb.optim.AdamW(model.parameters(), lr=1e-5, betas=(0.9, 0.999), optim_bits=8)
# optimizer = Lamb(model.parameters(), lr=1e-5, betas=(0.9, 0.999), eps=1e-08, weight_decay=3e-6)

# res18 batch64 sgd from scratch 3e-5 - 3e-4 - 1e-3?

# Learning rate -> observed result (before -> after):
# 3e-5 27.9->25.0
# 3e-3 27.9->26.5
# 0.05 29.4 -> 22.1

epochs = 5
# NOTE(review): the step count is derived from train_loader, while the training
# cell iterates labeled_loader / unlabeled_loader -- confirm these agree.
num_batches=len(train_loader)
# Optimizer steps over all epochs: ceil(batches / grad_acc) per epoch.
total_steps=int(np.ceil(num_batches/grad_acc)*epochs +1) # +1 to exclude the uptick at the end of onecycle
# total_steps=100 # 300000
warmup_steps = 50 # default 0 / mainargs 5000
# t_scheduler = get_cosine_schedule_with_warmup(t_optimizer, warmup_steps, total_steps)
# NOTE(review): the 4th positional argument (10) looks like the teacher's
# counterpart of student_wait_steps below -- confirm against the helper's signature.
t_scheduler = get_cosine_schedule_with_warmup(t_optimizer, warmup_steps, total_steps,10)
student_wait_steps = 30 # default 0 / mainargs 3000
s_scheduler = get_cosine_schedule_with_warmup(s_optimizer, warmup_steps, total_steps, student_wait_steps)



In [None]:

# scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=int(np.ceil(num_batches/4)*3), power=1.0)
# scheduler = PolynomialLR(optimizer, total_iters=4, power=1.0)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=10**(-1/2))


# t_optimizer.param_groups[0]["lr"]
# s_optimizer.param_groups[0]["lr"]

# import time
# start = time.time()

# Destination of the training state; the file is overwritten after every epoch.
pth = '/content/mpl.pth' # ty
# pth='/content/drive/MyDrive/frame/mpl18.pth' # M

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    t_lr = t_optimizer.param_groups[0]["lr"]
    s_lr = s_optimizer.param_groups[0]["lr"]
    print('t_lr,s_lr', t_lr, s_lr)
    # Variant that also passes the LR schedulers:
    # train(labeled_loader, unlabeled_loader, teacher_model, student_model,
    #     avg_student_model, criterion, t_optimizer, s_optimizer, t_scheduler, s_scheduler)
    train(labeled_loader, unlabeled_loader, teacher_model, student_model,
          avg_student_model, criterion, t_optimizer, s_optimizer)

    # evaluate(test_loader, student_model, criterion)
    evaluate(test_loader, avg_student_model, criterion)

    # Deliberately NOT named `checkpoint`: that name is bound earlier in the
    # notebook to torch.utils.checkpoint.checkpoint, and rebinding it to a
    # dict would break any later call to that function.
    ckpt_state = {
        'epoch': t + 1,
        'teacher_model': teacher_model.state_dict(),
        'student_model': student_model.state_dict(),
        'avg_student_model': avg_student_model.state_dict(),
        't_optimizer': t_optimizer.state_dict(),
        's_optimizer': s_optimizer.state_dict(),
        't_scheduler': t_scheduler.state_dict(),
        's_scheduler': s_scheduler.state_dict(),
    }
    torch.save(ckpt_state, pth)


# res34, batch4 8.8
# res18, batch16 11.6 nocompilemodel

# 16m28s



In [None]:

# finetune
# Fine-tune the averaged student on the labeled fine-tuning set.
# model = student_model
model = avg_student_model
model.drop = nn.Identity()  # replace the `drop` submodule with a no-op for fine-tuning
# labeled_loader = DataLoader(finetune_dataset, batch_size=128, num_workers=4, pin_memory=True) # batch_size=512
labeled_loader = DataLoader(finetune_dataset, batch_size=128, pin_memory=True) # batch_size=512
optimizer = optim.SGD(model.parameters(), lr=3e-5, momentum=0.9, weight_decay=0, nesterov=True)
# scaler = amp.GradScaler()
for epoch in range(1): #625
    # train_ls = strain(labeled_loader, model, loss_fn, optimizer, scheduler)
    train_ls = strain(labeled_loader, model, criterion, optimizer)
    # Score the model that was just fine-tuned; `student_model` is not updated
    # by this loop, so evaluating it would not reflect fine-tuning progress.
    evaluate(test_loader, model, criterion)


In [None]:
# Final score: evaluate() returns (fraction of correct predictions, loss).
# NOTE(review): this scores student_model, whereas the fine-tuning cell trains
# avg_student_model (via `model`) -- confirm which model is meant here.
correct, test_loss = evaluate(test_loader, student_model, criterion)
print(correct, test_loss)