In [1]:
!jupyter nbconvert --to script BYOL.ipynb
import BYOL
import torch
import time
from pretrain_dataloader import *
from lars import LARSWrapper
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from utils import *

device = 'cuda' if torch.cuda.is_available() else 'cpu'
import torch.nn.functional as F
torch.backends.cudnn.enabled = True
import copy
import wandb

[NbConvertApp] Converting notebook BYOL.ipynb to script
[NbConvertApp] Writing 2969 bytes to BYOL.py


  stdout_func(


In [2]:
args = return_default_args()
train_loader = prepare_cifar_train_loader(args)
online = BYOL.BYOL_module().to('cuda').to(memory_format=torch.channels_last)
target = copy.deepcopy(online).to('cuda').to(memory_format=torch.channels_last)

Files already downloaded and verified


In [3]:
params = [
            {"name": "backbone", "params": online.network.parameters()},
            {
                "name": "classifier",
                "params": online.classifier.parameters(),
                "lr": args.classifier_lr,
                "weight_decay": args.classifier_weight_decay,
            },
            {
                "name": "momentum_classifier",
                "params": target.classifier.parameters(),
                "lr": args.classifier_lr,
                "weight_decay": args.classifier_weight_decay,
            },
            {'params': online.projector.parameters()},
            {'params':online.predictor.parameters(),}
        ]


opt = torch.optim.SGD(params, lr = .3, weight_decay = 1.5e-5, momentum = .9)
opt = LARSWrapper(opt, clip = True, exclude_bias_n_norm = True, eta = .001)
schedule = LinearWarmupCosineAnnealingLR(
                    opt,
                    warmup_epochs= 10 * 195,
                    max_epochs= 195 * 200,
                    warmup_start_lr=3e-05,
                    eta_min=0)


In [4]:
def training_step(data, online, target, temporal_ensembling):
    step_metrics = {}
    
    online1 = online(data[1][0])
    online2 = online(data[1][1])
    target1 = target.momentum_forward(data[1][0])
    target2 = target.momentum_forward(data[1][1])
    
    with torch.no_grad():
        average = (target1['z'] + target2['z']) / torch.tensor(2, device = device)
        temporal = temporal_ensembling[data[0]].to(device)
        temp_average = (temporal[:, 0, :] + temporal[:, 1, :]) / 2

        final_average = (temp_average + average) / 2
        temporal_ensembling[data[0], 0, :] = temporal_ensembling[data[0], 0, :] * .75 + .25 * target1['z'].cpu()
        temporal_ensembling[data[0], 1, :] = temporal_ensembling[data[0], 1, :] * .75 + .25 * target2['z'].cpu()
    
    #byol_loss
    step_metrics["byol_loss"] = byol_loss_func(online1['p'], final_average)
    step_metrics["byol_loss"] += byol_loss_func(online2['p'], final_average)
    
    #cross entropy loss
    step_metrics["online_cross_entropy_loss"] = F.cross_entropy(online1['logits'], data[2], ignore_index=-1)
    step_metrics["momentum_cross_entropy_loss"] = F.cross_entropy(target1['logits'], data[2], ignore_index=-1)
    
    #accuracy of predictions
    _, predicted = torch.max(online1['logits'], 1)
    step_metrics["online_acc1"] = (predicted == data[2]).sum()
    _, predicted = torch.max(target1['logits'], 1)
    step_metrics["target_acc1"] = (predicted == data[2]).sum()
    
    _, pred = online1['logits'].topk(5)
    data[2] = data[2].unsqueeze(1).expand_as(pred)
    step_metrics["online_acc5"] = (data[2] == pred).any(dim = 1).sum()
    _, pred = target1['logits'].topk(5)
    step_metrics["target_acc5"] = (data[2] == pred).any(dim = 1).sum()
                   
        
    
    
    #metrics to track
    with torch.no_grad():
        online1p_softmax = F.softmax(online1['p'], dim = 1)
        online2p_softmax = F.softmax(online2['p'], dim = 1)
        online1rep_softmax = F.softmax(online1['Representation'], dim = 1)
        online2rep_softmax = F.softmax(online2['Representation'], dim = 1)
        online1z_softmax = F.softmax(online1['z'], dim = 1)
        online2z_softmax = F.softmax(online2['z'], dim = 1)
        
        target1z_softmax = F.softmax(target1['z'], dim = 1)
        target2z_softmax = F.softmax(target1['z'], dim = 1)
        target1rep_softmax = F.softmax(target1['Representation'], dim = 1)
        target2rep_softmax = F.softmax(target2['Representation'], dim = 1)
        
        
        cross_entropy = F.cross_entropy(online1p_softmax, target2z_softmax, ignore_index=-1) + F.cross_entropy(online2p_softmax, target1z_softmax, ignore_index=-1)
        l1_dist = F.l1_loss(online1['p'], target2['z']) + F.l1_loss(online2['p'], target1['z'])
        l2_dist = F.mse_loss(online1['p'], target2['z']) + F.mse_loss(online2['p'], target1['z'])
        smooth_l1 = F.smooth_l1_loss(online1['p'], target2['z']) + F.smooth_l1_loss(online2['p'], target1['z'])
        kl_div = F.kl_div(online1p_softmax, target2z_softmax) + F.kl_div(online2p_softmax, target1z_softmax)

        representation_l1 = F.l1_loss(online1['Representation'], target2['Representation']) + F.l1_loss(online2['Representation'], target1['Representation'])
        representation_l2 = F.mse_loss(online1['Representation'], target2['Representation']) + F.mse_loss(online2['Representation'], target1['Representation'])
        representation_cross_entropy = F.cross_entropy(online1rep_softmax, target2rep_softmax, ignore_index=-1) + F.cross_entropy(online2rep_softmax, target1rep_softmax, ignore_index=-1)
        representation_kl = F.kl_div(online1rep_softmax, target2rep_softmax) + F.kl_div(online2rep_softmax, target1rep_softmax)
        representation_cos_sim = F.cosine_similarity(online1['Representation'], target2['Representation']) + F.cosine_similarity(online2['Representation'], target1['Representation'])

        projection_l1 = F.l1_loss(online1['z'], target2['z']) + F.l1_loss(online2['z'], target1['z'])
        projection_l2 = F.mse_loss(online1['z'], target2['z']) + F.mse_loss(online2['z'], target1['z'])
        projection_cross_entropy = F.cross_entropy(online1z_softmax, target2z_softmax, ignore_index=-1) + F.cross_entropy(online2z_softmax, target1z_softmax, ignore_index=-1)
        projection_kl = F.kl_div(online1z_softmax, target2z_softmax) + F.kl_div(online2z_softmax, target1z_softmax)
        projection_cos_sim = F.cosine_similarity(online1['z'], target2['z']) + F.cosine_similarity(online2['z'], target1['z'])

        momentum_projection_cos_sim = F.cosine_similarity(target1['z'], target2['z'])
        momentum_representation_cos_sim = F.cosine_similarity(target1['Representation'], target2['Representation'])
        momentum_representation_l2 =  F.mse_loss(target1['Representation'], target2['Representation'])
        momentum_projection_l2 =  F.mse_loss(target1['z'], target2['z'])

        online_projection_cos_sim = F.cosine_similarity(online1['z'], online2['z'])
        online_representation_cos_sim = F.cosine_similarity(online1['Representation'], online2['Representation'])
        online_representation_l2 =  F.mse_loss(online1['Representation'], online2['Representation'])
        online_projection_l2 =  F.mse_loss(online1['z'], online2['z'])
        
        online_representation_std = torch.mean(torch.std(online1['Representation']))
        online_projection_std = torch.mean(torch.std(online1['z']))
        online_prediction_std = torch.mean(torch.std(online1['p']))
        target_representation_std = torch.mean(torch.std(target1['Representation']))
        target_projection_std = torch.mean(torch.std(target1['z']))

        step_metrics.update({
            "train_feats_cross_entropy": cross_entropy,
            "train_feats_l1_dist": l1_dist,
            "train_feats_l2_dist": l2_dist,
            "train_feats_smooth_l1": smooth_l1,
            "train_feats_kl_div": kl_div,
            "representation_l1": representation_l1,
            "representation_l2": representation_l1,
            "representation_cross_entropy": representation_cross_entropy,
            "representation_kl": representation_kl,
            "representation_cos_sim": representation_cos_sim.mean(),
            "projection_l1": projection_l1,
            "projection_l2": projection_l2,
            "projection_cross_entropy": projection_cross_entropy,
            "projection_kl": projection_kl,
            "projection_cos_sim": projection_cos_sim.mean(),
            "momentum_projection_cos_sim": momentum_projection_cos_sim.mean(),
            "momentum_representation_cos_sim": momentum_representation_cos_sim.mean(),
            "momentum_representation_l2": momentum_representation_l2,
            "momentum_representation_l2": momentum_projection_l2,
            "online_projection_cos_sim": online_projection_cos_sim.mean(),
            "online_representation_cos_sim": online_representation_cos_sim.mean(),
            "online_representation_l2": online_representation_l2,
            "online_projection_l2": online_projection_l2,
            "online_representation_std": online_representation_std,
            "online_projection_std": online_projection_std,
            "online_prediction_std": online_prediction_std,
            "target_representation_std": target_representation_std,
            "target_projection_std": target_projection_std,
            })
    return step_metrics["byol_loss"] + step_metrics["online_cross_entropy_loss"] + \
            step_metrics["momentum_cross_entropy_loss"], step_metrics




In [5]:
log = True
name = 'BYOL - Temp Ensemble .75 Old 99 Momentum'
if log:
    wandb.init(config = args.__dict__, name = name)

[34m[1mwandb[0m: Currently logged in as: [33mdcaustin33[0m (use `wandb login --relogin` to force relogin)


In [None]:
step = 0
scaler = torch.cuda.amp.GradScaler()
temporal_ensembling = torch.randn(50001, 2, 256) #in case the indexs are not zero indexed

for e in range(200):
    metrics = {
        "byol_loss": 0,
        'online_cross_entropy_loss': 0,
        'momentum_cross_entropy_loss': 0,
        'online_acc1': 0,
        'target_acc1': 0,
        'online_acc5': 0,
        'target_acc5': 0,
        "train_feats_cross_entropy": 0,
        "train_feats_l1_dist": 0,
        "train_feats_l2_dist": 0,
        "train_feats_smooth_l1": 0,
        "train_feats_kl_div": 0,
        "representation_l1": 0,
        "representation_l2": 0,
        "representation_cross_entropy": 0,
        "representation_kl": 0,
        "representation_cos_sim": 0,
        "projection_l1": 0,
        "projection_l2": 0,
        "projection_cross_entropy": 0,
        "projection_kl": 0,
        "projection_cos_sim": 0,
        "momentum_projection_cos_sim": 0,
        "momentum_representation_cos_sim": 0,
        "momentum_representation_l2": 0,
        "momentum_representation_l2": 0,
        "online_projection_cos_sim": 0,
        "online_representation_cos_sim": 0,
        "online_representation_l2": 0,
        "online_projection_l2": 0,
        "online_representation_std": 0,
        "online_projection_std": 0,
        "online_prediction_std": 0,
        "target_representation_std": 0,
        "target_projection_std": 0,
        }
    
    
    total_loss = 0
    now = time.time()
    for i_, data in enumerate(train_loader):
        
        with torch.autocast('cuda'):
            step += 1
            data[0] = data[0].to('cuda')
            data[1][0] = data[1][0].to('cuda').to(memory_format=torch.channels_last)
            data[1][1] = data[1][1].to('cuda').to(memory_format=torch.channels_last)
            data[2] = data[2].to('cuda')
            
            
            loss, step_metrics = training_step(data, online, target, temporal_ensembling)
            opt.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(opt)
            schedule.step()
            scaler.update()

            update_target_params(online.network.parameters(), target.network.parameters(), .99)
            update_target_params(online.projector.parameters(), target.projector.parameters(), .99)
                
        total_loss += loss
        for key in step_metrics:
            metrics[key] += step_metrics[key]

    print(e, time.time() - now, total_loss / (i_+1), metrics['byol_loss'] / (i_+1), 
          metrics['online_acc1'] / 50000, metrics['target_acc1'] / 50000)
    if log:
        i_ += 1
        for key in metrics:
            if key[-4:-1] == 'acc':
                metrics[key] = metrics[key] / 50000
            else:
                metrics[key] = metrics[key]/i_
        wandb.log(metrics)
    checkpoint = {
    'online': online.state_dict(),
    'target': target.state_dict(),
    'epoch': e,
    'optimizer': opt.optim.state_dict()
    }
    torch.save(checkpoint, 'checkpoint.pth')
    



0 56.514204025268555 tensor(11.7800, device='cuda:0', grad_fn=<DivBackward0>) tensor(2.6434, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.0194, device='cuda:0') tensor(0.0177, device='cuda:0')
1 56.81884527206421 tensor(10.3899, device='cuda:0', grad_fn=<DivBackward0>) tensor(1.5560, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.0375, device='cuda:0') tensor(0.0354, device='cuda:0')
2 56.344401597976685 tensor(9.5313, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.9106, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.0532, device='cuda:0') tensor(0.0489, device='cuda:0')
3 57.230048418045044 tensor(9.0621, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.6225, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.0657, device='cuda:0') tensor(0.0635, device='cuda:0')
4 56.36947011947632 tensor(8.7656, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.5016, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.0780, device='cuda:0') tensor(0.0755, device='cuda:0')
5 56.4196512699

42 56.550097703933716 tensor(5.6658, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.2643, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.3271, device='cuda:0') tensor(0.3325, device='cuda:0')
43 56.17641830444336 tensor(5.6385, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.2652, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.3261, device='cuda:0') tensor(0.3301, device='cuda:0')
44 55.895565032958984 tensor(5.5752, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.2639, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.3338, device='cuda:0') tensor(0.3382, device='cuda:0')
45 56.13703274726868 tensor(5.5565, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.2642, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.3340, device='cuda:0') tensor(0.3385, device='cuda:0')
46 56.07009291648865 tensor(5.5259, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.2664, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.3372, device='cuda:0') tensor(0.3418, device='cuda:0')
47 56.0943269

In [None]:
torch.save(online.state_dict(), name)

In [None]:
args = return_default_args()
train_loader = prepare_cifar_train_loader(args)
online = BYOL.BYOL_module().to('cuda').to(memory_format=torch.channels_last)
target = copy.deepcopy(online).to('cuda').to(memory_format=torch.channels_last)

In [None]:
params = [
            {"name": "backbone", "params": online.network.parameters()},
            {
                "name": "classifier",
                "params": online.classifier.parameters(),
                "lr": args.classifier_lr,
                "weight_decay": args.classifier_weight_decay,
            },
            {
                "name": "momentum_classifier",
                "params": target.classifier.parameters(),
                "lr": args.classifier_lr,
                "weight_decay": args.classifier_weight_decay,
            },
            {'params': online.projector.parameters()},
            {'params':online.predictor.parameters(),}
        ]


opt = torch.optim.SGD(params, lr = .3, weight_decay = 1.5e-5, momentum = .9)
opt = LARSWrapper(opt, clip = True, exclude_bias_n_norm = True, eta = .001)
schedule = LinearWarmupCosineAnnealingLR(
                    opt,
                    warmup_epochs= 10 * 195,
                    max_epochs= 195 * 200,
                    warmup_start_lr=3e-05,
                    eta_min=0)


In [None]:
def training_step(data, online, target, temporal_ensembling):
    step_metrics = {}
    
    online1 = online(data[1][0])
    online2 = online(data[1][1])
    target1 = target.momentum_forward(data[1][0])
    target2 = target.momentum_forward(data[1][1])
    
    with torch.no_grad():
        average = (target1['z'] + target2['z']) / torch.tensor(2, device = device)
        temporal = temporal_ensembling[data[0]].to(device)
        temp_average = (temporal[:, 0, :] + temporal[:, 1, :]) / 2

        final_average = (temp_average + average) / 2
        temporal_ensembling[data[0], 0, :] = temporal_ensembling[data[0], 0, :] * .25 + .75 * target1['z'].cpu()
        temporal_ensembling[data[0], 1, :] = temporal_ensembling[data[0], 1, :] * .25 + .75 * target2['z'].cpu()
    
    #byol_loss
    step_metrics["byol_loss"] = byol_loss_func(online1['p'], final_average)
    step_metrics["byol_loss"] += byol_loss_func(online2['p'], final_average)
    
    #cross entropy loss
    step_metrics["online_cross_entropy_loss"] = F.cross_entropy(online1['logits'], data[2], ignore_index=-1)
    step_metrics["momentum_cross_entropy_loss"] = F.cross_entropy(target1['logits'], data[2], ignore_index=-1)
    
    #accuracy of predictions
    _, predicted = torch.max(online1['logits'], 1)
    step_metrics["online_acc1"] = (predicted == data[2]).sum()
    _, predicted = torch.max(target1['logits'], 1)
    step_metrics["target_acc1"] = (predicted == data[2]).sum()
    
    _, pred = online1['logits'].topk(5)
    data[2] = data[2].unsqueeze(1).expand_as(pred)
    step_metrics["online_acc5"] = (data[2] == pred).any(dim = 1).sum()
    _, pred = target1['logits'].topk(5)
    step_metrics["target_acc5"] = (data[2] == pred).any(dim = 1).sum()
                   
        
    
    
    #metrics to track
    with torch.no_grad():
        online1p_softmax = F.softmax(online1['p'], dim = 1)
        online2p_softmax = F.softmax(online2['p'], dim = 1)
        online1rep_softmax = F.softmax(online1['Representation'], dim = 1)
        online2rep_softmax = F.softmax(online2['Representation'], dim = 1)
        online1z_softmax = F.softmax(online1['z'], dim = 1)
        online2z_softmax = F.softmax(online2['z'], dim = 1)
        
        target1z_softmax = F.softmax(target1['z'], dim = 1)
        target2z_softmax = F.softmax(target1['z'], dim = 1)
        target1rep_softmax = F.softmax(target1['Representation'], dim = 1)
        target2rep_softmax = F.softmax(target2['Representation'], dim = 1)
        
        
        cross_entropy = F.cross_entropy(online1p_softmax, target2z_softmax, ignore_index=-1) + F.cross_entropy(online2p_softmax, target1z_softmax, ignore_index=-1)
        l1_dist = F.l1_loss(online1['p'], target2['z']) + F.l1_loss(online2['p'], target1['z'])
        l2_dist = F.mse_loss(online1['p'], target2['z']) + F.mse_loss(online2['p'], target1['z'])
        smooth_l1 = F.smooth_l1_loss(online1['p'], target2['z']) + F.smooth_l1_loss(online2['p'], target1['z'])
        kl_div = F.kl_div(online1p_softmax, target2z_softmax) + F.kl_div(online2p_softmax, target1z_softmax)

        representation_l1 = F.l1_loss(online1['Representation'], target2['Representation']) + F.l1_loss(online2['Representation'], target1['Representation'])
        representation_l2 = F.mse_loss(online1['Representation'], target2['Representation']) + F.mse_loss(online2['Representation'], target1['Representation'])
        representation_cross_entropy = F.cross_entropy(online1rep_softmax, target2rep_softmax, ignore_index=-1) + F.cross_entropy(online2rep_softmax, target1rep_softmax, ignore_index=-1)
        representation_kl = F.kl_div(online1rep_softmax, target2rep_softmax) + F.kl_div(online2rep_softmax, target1rep_softmax)
        representation_cos_sim = F.cosine_similarity(online1['Representation'], target2['Representation']) + F.cosine_similarity(online2['Representation'], target1['Representation'])

        projection_l1 = F.l1_loss(online1['z'], target2['z']) + F.l1_loss(online2['z'], target1['z'])
        projection_l2 = F.mse_loss(online1['z'], target2['z']) + F.mse_loss(online2['z'], target1['z'])
        projection_cross_entropy = F.cross_entropy(online1z_softmax, target2z_softmax, ignore_index=-1) + F.cross_entropy(online2z_softmax, target1z_softmax, ignore_index=-1)
        projection_kl = F.kl_div(online1z_softmax, target2z_softmax) + F.kl_div(online2z_softmax, target1z_softmax)
        projection_cos_sim = F.cosine_similarity(online1['z'], target2['z']) + F.cosine_similarity(online2['z'], target1['z'])

        momentum_projection_cos_sim = F.cosine_similarity(target1['z'], target2['z'])
        momentum_representation_cos_sim = F.cosine_similarity(target1['Representation'], target2['Representation'])
        momentum_representation_l2 =  F.mse_loss(target1['Representation'], target2['Representation'])
        momentum_projection_l2 =  F.mse_loss(target1['z'], target2['z'])

        online_projection_cos_sim = F.cosine_similarity(online1['z'], online2['z'])
        online_representation_cos_sim = F.cosine_similarity(online1['Representation'], online2['Representation'])
        online_representation_l2 =  F.mse_loss(online1['Representation'], online2['Representation'])
        online_projection_l2 =  F.mse_loss(online1['z'], online2['z'])
        
        online_representation_std = torch.mean(torch.std(online1['Representation']))
        online_projection_std = torch.mean(torch.std(online1['z']))
        online_prediction_std = torch.mean(torch.std(online1['p']))
        target_representation_std = torch.mean(torch.std(target1['Representation']))
        target_projection_std = torch.mean(torch.std(target1['z']))

        step_metrics.update({
            "train_feats_cross_entropy": cross_entropy,
            "train_feats_l1_dist": l1_dist,
            "train_feats_l2_dist": l2_dist,
            "train_feats_smooth_l1": smooth_l1,
            "train_feats_kl_div": kl_div,
            "representation_l1": representation_l1,
            "representation_l2": representation_l1,
            "representation_cross_entropy": representation_cross_entropy,
            "representation_kl": representation_kl,
            "representation_cos_sim": representation_cos_sim.mean(),
            "projection_l1": projection_l1,
            "projection_l2": projection_l2,
            "projection_cross_entropy": projection_cross_entropy,
            "projection_kl": projection_kl,
            "projection_cos_sim": projection_cos_sim.mean(),
            "momentum_projection_cos_sim": momentum_projection_cos_sim.mean(),
            "momentum_representation_cos_sim": momentum_representation_cos_sim.mean(),
            "momentum_representation_l2": momentum_representation_l2,
            "momentum_representation_l2": momentum_projection_l2,
            "online_projection_cos_sim": online_projection_cos_sim.mean(),
            "online_representation_cos_sim": online_representation_cos_sim.mean(),
            "online_representation_l2": online_representation_l2,
            "online_projection_l2": online_projection_l2,
            "online_representation_std": online_representation_std,
            "online_projection_std": online_projection_std,
            "online_prediction_std": online_prediction_std,
            "target_representation_std": target_representation_std,
            "target_projection_std": target_projection_std,
            })
    return step_metrics["byol_loss"] + step_metrics["online_cross_entropy_loss"] + \
            step_metrics["momentum_cross_entropy_loss"], step_metrics




In [None]:
log = True
name = 'BYOL - Temp Ensemble .25 Old 99 Momentum'
if log:
    wandb.init(config = args.__dict__, name = name)

In [None]:
step = 0
scaler = torch.cuda.amp.GradScaler()
temporal_ensembling = torch.randn(50001, 2, 256) #in case the indexs are not zero indexed

for e in range(200):
    metrics = {
        "byol_loss": 0,
        'online_cross_entropy_loss': 0,
        'momentum_cross_entropy_loss': 0,
        'online_acc1': 0,
        'target_acc1': 0,
        'online_acc5': 0,
        'target_acc5': 0,
        "train_feats_cross_entropy": 0,
        "train_feats_l1_dist": 0,
        "train_feats_l2_dist": 0,
        "train_feats_smooth_l1": 0,
        "train_feats_kl_div": 0,
        "representation_l1": 0,
        "representation_l2": 0,
        "representation_cross_entropy": 0,
        "representation_kl": 0,
        "representation_cos_sim": 0,
        "projection_l1": 0,
        "projection_l2": 0,
        "projection_cross_entropy": 0,
        "projection_kl": 0,
        "projection_cos_sim": 0,
        "momentum_projection_cos_sim": 0,
        "momentum_representation_cos_sim": 0,
        "momentum_representation_l2": 0,
        "momentum_representation_l2": 0,
        "online_projection_cos_sim": 0,
        "online_representation_cos_sim": 0,
        "online_representation_l2": 0,
        "online_projection_l2": 0,
        "online_representation_std": 0,
        "online_projection_std": 0,
        "online_prediction_std": 0,
        "target_representation_std": 0,
        "target_projection_std": 0,
        }
    
    
    total_loss = 0
    now = time.time()
    for i_, data in enumerate(train_loader):
        
        with torch.autocast('cuda'):
            step += 1
            data[0] = data[0].to('cuda')
            data[1][0] = data[1][0].to('cuda').to(memory_format=torch.channels_last)
            data[1][1] = data[1][1].to('cuda').to(memory_format=torch.channels_last)
            data[2] = data[2].to('cuda')
            
            
            loss, step_metrics = training_step(data, online, target, temporal_ensembling)
            opt.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(opt)
            schedule.step()
            scaler.update()

            update_target_params(online.network.parameters(), target.network.parameters(), .99)
            update_target_params(online.projector.parameters(), target.projector.parameters(), .99)
                
        total_loss += loss
        for key in step_metrics:
            metrics[key] += step_metrics[key]

    print(e, time.time() - now, total_loss / (i_+1), metrics['byol_loss'] / (i_+1), 
          metrics['online_acc1'] / 50000, metrics['target_acc1'] / 50000)
    if log:
        i_ += 1
        for key in metrics:
            if key[-4:-1] == 'acc':
                metrics[key] = metrics[key] / 50000
            else:
                metrics[key] = metrics[key]/i_
        wandb.log(metrics)
    checkpoint = {
    'online': online.state_dict(),
    'target': target.state_dict(),
    'epoch': e,
    'optimizer': opt.optim.state_dict()
    }
    torch.save(checkpoint, 'checkpoint.pth')
    

In [None]:
torch.save(online.state_dict(), name)