In [1]:
!jupyter nbconvert --to script BYOL.ipynb
import BYOL
import torch
import time
from pretrain_dataloader import *
from lars import LARSWrapper
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR

device = 'cuda' if torch.cuda.is_available() else 'cpu'
import torch.nn.functional as F
torch.backends.cudnn.enabled = True
import copy
import wandb

[NbConvertApp] Converting notebook BYOL.ipynb to script
[NbConvertApp] Writing 2969 bytes to BYOL.py


In [2]:
class args:
    def __init__(self):
        return
args = args()

args.dataset = 'cifar100'
args.transform_kwargs=[{'brightness': 0.4, 'contrast': 0.4, 'saturation': 0.2, 'hue': 0.1, 'color_jitter_prob': 0.8, 'gray_scale_prob': 0.2, 'horizontal_flip_prob': 0.5, 'gaussian_prob': 1.0, 'solarization_prob': 0.0, 'crop_size': 32, 'min_scale': 0.08, 'max_scale': 1.0}, {'brightness': 0.4, 'contrast': 0.4, 'saturation': 0.2, 'hue': 0.1, 'color_jitter_prob': 0.8, 'gray_scale_prob': 0.2, 'horizontal_flip_prob': 0.5, 'gaussian_prob': 0.1, 'solarization_prob': 0.2, 'crop_size': 32, 'min_scale': 0.08, 'max_scale': 1.0}]
# asymmetric augmentations
args.num_crops_per_aug = [1, 1]
args.batch_size = 256
args.num_workers = 4


transform = [
    prepare_transform(args.dataset, **kwargs) for kwargs in args.transform_kwargs
]

transform = prepare_n_crop_transform(transform, num_crops_per_aug=args.num_crops_per_aug)

train_dataset = prepare_datasets(
    args.dataset,
    transform,
    no_labels=False,
)
train_loader = prepare_dataloader(
    train_dataset, batch_size=args.batch_size, num_workers=args.num_workers
)

Files already downloaded and verified


In [3]:
def update_target_params(online_params, target_params, tau):

    #update the backbone first
    for op, mp in zip(online_params, target_params):
        mp.data = tau * mp.data + (1 - tau) * op.data
    
def byol_loss_func(p: torch.Tensor, z: torch.Tensor, simplified: bool = True) -> torch.Tensor:
    return 2 - 2 * F.cosine_similarity(p, z.detach(), dim=-1).mean()

online = BYOL.BYOL_module().to('cuda').to(memory_format=torch.channels_last)
target = copy.deepcopy(online).to('cuda').to(memory_format=torch.channels_last)

In [4]:
args.optimizer = 'LARS'
args.lr = .3
args.weight_decay = 1.5e-6
args.momentum = .9
args.classifier_lr = .1
args.classifier_weight_decay = 0
args.epochs = 200
args.warmup_epochs = 10
args.steps = int((50000/args.batch_size) * args.epochs)
args.warmup_steps = int((50000/args.batch_size) * args.warmup_epochs)

params = [
            {"name": "backbone", "params": online.network.parameters()},
            {
                "name": "classifier",
                "params": online.classifier.parameters(),
                "lr": args.classifier_lr,
                "weight_decay": args.classifier_weight_decay,
            },
            {
                "name": "momentum_classifier",
                "params": target.classifier.parameters(),
                "lr": args.classifier_lr,
                "weight_decay": args.classifier_weight_decay,
            },
            {'params': online.projector.parameters()},
            {'params':online.predictor.parameters(),}
        ]


opt = torch.optim.SGD(params, lr = .3, weight_decay = 1.5e-5, momentum = .9)
opt = LARSWrapper(opt, clip = True, exclude_bias_n_norm = True, eta = .001)
schedule = LinearWarmupCosineAnnealingLR(
                    opt,
                    warmup_epochs= 10 * 195,
                    max_epochs= 195 * 200,
                    warmup_start_lr=3e-05,
                    eta_min=0)


In [5]:
def training_step(data, online, target):
    step_metrics = {}
    
    online1 = online(data[1][0])
    online2 = online(data[1][1])
    target1 = target.momentum_forward(data[1][0])
    target2 = target.momentum_forward(data[1][1])
    
    #byol_loss
    step_metrics["byol_loss"] = byol_loss_func(online1['p'], target2['z'])
    step_metrics["byol_loss"] += byol_loss_func(online2['p'], target1['z'])
    
    #cross entropy loss
    step_metrics["online_cross_entropy_loss"] = F.cross_entropy(online1['logits'], data[2], ignore_index=-1)
    step_metrics["momentum_cross_entropy_loss"] = F.cross_entropy(target1['logits'], data[2], ignore_index=-1)
    
    #accuracy of predictions
    _, predicted = torch.max(online1['logits'], 1)
    step_metrics["online_acc1"] = (predicted == data[2]).sum()
    _, predicted = torch.max(target1['logits'], 1)
    step_metrics["target_acc1"] = (predicted == data[2]).sum()
    
    _, pred = online1['logits'].topk(5)
    data[2] = data[2].unsqueeze(1).expand_as(pred)
    step_metrics["online_acc5"] = (data[2] == pred).any(dim = 1).sum()
    _, pred = target1['logits'].topk(5)
    step_metrics["target_acc5"] = (data[2] == pred).any(dim = 1).sum()
                   
        
    
    
    #metrics to track
    with torch.no_grad():
        cross_entropy = F.cross_entropy(online1['p'], target2['z'], ignore_index=-1) + F.cross_entropy(online2['p'], target1['z'], ignore_index=-1)
        l1_dist = F.l1_loss(online1['p'], target2['z']) + F.l1_loss(online2['p'], target1['z'])
        l2_dist = F.mse_loss(online1['p'], target2['z']) + F.mse_loss(online2['p'], target1['z'])
        smooth_l1 = F.smooth_l1_loss(online1['p'], target2['z']) + F.smooth_l1_loss(online2['p'], target1['z'])
        kl_div = F.kl_div(online1['p'], target2['z']) + F.kl_div(online2['p'], target1['z'])

        representation_l1 = F.l1_loss(online1['Representation'], target2['Representation']) + F.l1_loss(online2['Representation'], target1['Representation'])
        representation_l2 = F.mse_loss(online1['Representation'], target2['Representation']) + F.mse_loss(online2['Representation'], target1['Representation'])
        representation_cross_entropy = F.cross_entropy(online1['Representation'], target2['Representation'], ignore_index=-1) + F.cross_entropy(online2['Representation'], target1['Representation'], ignore_index=-1)
        representation_kl = F.kl_div(online1['Representation'], target2['Representation']) + F.kl_div(online2['Representation'], target1['Representation'])
        representation_cos_sim = F.cosine_similarity(online1['Representation'], target2['Representation']) + F.cosine_similarity(online2['Representation'], target1['Representation'])

        projection_l1 = F.l1_loss(online1['z'], target2['z']) + F.l1_loss(online2['z'], target1['z'])
        projection_l2 = F.mse_loss(online1['z'], target2['z']) + F.mse_loss(online2['z'], target1['z'])
        projection_cross_entropy = F.cross_entropy(online1['z'], target2['z'], ignore_index=-1) + F.cross_entropy(online2['z'], target1['z'], ignore_index=-1)
        projection_kl = F.kl_div(online1['z'], target2['z']) + F.kl_div(online2['z'], target1['z'])
        projection_cos_sim = F.cosine_similarity(online1['z'], target2['z']) + F.cosine_similarity(online2['z'], target1['z'])

        momentum_projection_cos_sim = F.cosine_similarity(target1['z'], target2['z'])
        momentum_representation_cos_sim = F.cosine_similarity(target1['Representation'], target2['Representation'])
        momentum_representation_l2 =  F.mse_loss(target1['Representation'], target2['Representation'])
        momentum_projection_l2 =  F.mse_loss(target1['z'], target2['z'])

        online_projection_cos_sim = F.cosine_similarity(online1['z'], online2['z'])
        online_representation_cos_sim = F.cosine_similarity(online1['Representation'], online2['Representation'])
        online_representation_l2 =  F.mse_loss(online1['Representation'], online2['Representation'])
        online_projection_l2 =  F.mse_loss(online1['z'], online2['z'])

        step_metrics.update({
            "train_feats_cross_entropy": cross_entropy,
            "train_feats_l1_dist": l1_dist,
            "train_feats_l2_dist": l2_dist,
            "train_feats_smooth_l1": smooth_l1,
            "train_feats_kl_div": kl_div,
            "representation_l1": representation_l1,
            "representation_l2": representation_l1,
            "representation_cross_entropy": representation_cross_entropy,
            "representation_kl": representation_kl,
            "representation_cos_sim": representation_cos_sim.mean(),
            "projection_l1": projection_l1,
            "projection_l2": projection_l2,
            "projection_cross_entropy": projection_cross_entropy,
            "projection_kl": projection_kl,
            "projection_cos_sim": projection_cos_sim.mean(),
            "momentum_projection_cos_sim": momentum_projection_cos_sim.mean(),
            "momentum_representation_cos_sim": momentum_representation_cos_sim.mean(),
            "momentum_representation_l2": momentum_representation_l2,
            "momentum_representation_l2": momentum_projection_l2,
            "online_projection_cos_sim": online_projection_cos_sim.mean(),
            "online_representation_cos_sim": online_representation_cos_sim.mean(),
            "online_representation_l2": online_representation_l2,
            "online_projection_l2": online_projection_l2,
            })
    return step_metrics["byol_loss"] + step_metrics["online_cross_entropy_loss"] + \
            step_metrics["momentum_cross_entropy_loss"], step_metrics




In [6]:
log = True
if log:
    wandb.init(config = args.__dict__)

[34m[1mwandb[0m: Currently logged in as: [33mdcaustin33[0m (use `wandb login --relogin` to force relogin)


In [7]:
step = 0
scaler = torch.cuda.amp.GradScaler()
for e in range(200):
    metrics = {
        "byol_loss": 0,
        'online_cross_entropy_loss': 0,
        'momentum_cross_entropy_loss': 0,
        'online_acc1': 0,
        'target_acc1': 0,
        'online_acc5': 0,
        'target_acc5': 0,
        "train_feats_cross_entropy": 0,
        "train_feats_l1_dist": 0,
        "train_feats_l2_dist": 0,
        "train_feats_smooth_l1": 0,
        "train_feats_kl_div": 0,
        "representation_l1": 0,
        "representation_l2": 0,
        "representation_cross_entropy": 0,
        "representation_kl": 0,
        "representation_cos_sim": 0,
        "projection_l1": 0,
        "projection_l2": 0,
        "projection_cross_entropy": 0,
        "projection_kl": 0,
        "projection_cos_sim": 0,
        "momentum_projection_cos_sim": 0,
        "momentum_representation_cos_sim": 0,
        "momentum_representation_l2": 0,
        "momentum_representation_l2": 0,
        "online_projection_cos_sim": 0,
        "online_representation_cos_sim": 0,
        "online_representation_l2": 0,
        "online_projection_l2": 0,
        }
    total_loss = 0
    now = time.time()
    for i_, data in enumerate(train_loader):
        
        with torch.autocast('cuda'):
            step += 1
            data[0] = data[0].to('cuda')
            data[1][0] = data[1][0].to('cuda').to(memory_format=torch.channels_last)
            data[1][1] = data[1][1].to('cuda').to(memory_format=torch.channels_last)
            data[2] = data[2].to('cuda')
            
            
            loss, step_metrics = training_step(data, online, target)
            opt.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(opt)
            schedule.step()
            scaler.update()

            update_target_params(online.network.parameters(), target.network.parameters(), .99)
            update_target_params(online.projector.parameters(), target.projector.parameters(), .99)
                
        total_loss += loss
        for key in step_metrics:
            metrics[key] += step_metrics[key]

    print(e, time.time() - now, total_loss / (i_+1), metrics['byol_loss'] / (i_+1), 
          metrics['online_acc1'] / 50000, metrics['target_acc1'] / 50000)
    if log:
        i_ += 1
        for key in metrics:
            if key[-4:-1] == 'acc':
                metrics[key] = metrics[key] / 50000
            else:
                metrics[key] = metrics[key]/i_
        wandb.log(metrics)
    



0 54.94324517250061 tensor(10.7328, device='cuda:0', grad_fn=<DivBackward0>) tensor(1.5949, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.0203, device='cuda:0') tensor(0.0180, device='cuda:0')
1 54.69605827331543 tensor(10.0428, device='cuda:0', grad_fn=<DivBackward0>) tensor(1.1775, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.0374, device='cuda:0') tensor(0.0353, device='cuda:0')
2 54.707274198532104 tensor(9.6586, device='cuda:0', grad_fn=<DivBackward0>) tensor(1.0389, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.0513, device='cuda:0') tensor(0.0488, device='cuda:0')
3 54.75790047645569 tensor(9.4311, device='cuda:0', grad_fn=<DivBackward0>) tensor(1.0009, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.0640, device='cuda:0') tensor(0.0632, device='cuda:0')
4 55.11026906967163 tensor(9.2865, device='cuda:0', grad_fn=<DivBackward0>) tensor(1.0109, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.0714, device='cuda:0') tensor(0.0716, device='cuda:0')
5 54.931285619735

42 54.20474076271057 tensor(6.2315, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.6845, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.3036, device='cuda:0') tensor(0.3120, device='cuda:0')
43 54.12798762321472 tensor(6.1785, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.6812, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.3067, device='cuda:0') tensor(0.3157, device='cuda:0')
44 54.32501578330994 tensor(6.1302, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.6710, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.3128, device='cuda:0') tensor(0.3208, device='cuda:0')
45 54.31333541870117 tensor(6.1005, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.6739, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.3140, device='cuda:0') tensor(0.3225, device='cuda:0')
46 54.365769386291504 tensor(6.0917, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.6761, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.3176, device='cuda:0') tensor(0.3240, device='cuda:0')
47 54.26241755

84 54.16388511657715 tensor(5.2769, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.6888, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.4039, device='cuda:0') tensor(0.4136, device='cuda:0')
85 54.40089130401611 tensor(5.2374, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.6846, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.4089, device='cuda:0') tensor(0.4173, device='cuda:0')
86 54.21867513656616 tensor(5.2543, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.6886, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.4069, device='cuda:0') tensor(0.4146, device='cuda:0')
87 54.2116060256958 tensor(5.2456, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.6869, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.4047, device='cuda:0') tensor(0.4135, device='cuda:0')
88 54.39958906173706 tensor(5.2041, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.6831, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.4130, device='cuda:0') tensor(0.4209, device='cuda:0')
89 54.3789265155

126 54.13011574745178 tensor(4.7807, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.6787, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.4603, device='cuda:0') tensor(0.4664, device='cuda:0')
127 54.206645250320435 tensor(4.7978, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.6790, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.4600, device='cuda:0') tensor(0.4643, device='cuda:0')
128 54.152130365371704 tensor(4.7566, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.6744, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.4606, device='cuda:0') tensor(0.4666, device='cuda:0')
129 53.99315905570984 tensor(4.7634, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.6759, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.4597, device='cuda:0') tensor(0.4662, device='cuda:0')
130 54.014474391937256 tensor(4.7365, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.6732, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.4622, device='cuda:0') tensor(0.4699, device='cuda:0')
131 53.

168 54.15423130989075 tensor(4.3884, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.6564, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.5053, device='cuda:0') tensor(0.5059, device='cuda:0')
169 54.04614615440369 tensor(4.3909, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.6580, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.5067, device='cuda:0') tensor(0.5083, device='cuda:0')
170 54.12036108970642 tensor(4.4069, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.6605, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.5021, device='cuda:0') tensor(0.5038, device='cuda:0')
171 53.828521728515625 tensor(4.3769, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.6589, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.5070, device='cuda:0') tensor(0.5079, device='cuda:0')
172 53.66307210922241 tensor(4.3770, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.6593, device='cuda:0', grad_fn=<DivBackward0>) tensor(0.5067, device='cuda:0') tensor(0.5084, device='cuda:0')
173 54.01