In [1]:
!jupyter nbconvert --to script BYOL.ipynb
import BYOL
import torch
import time
from pretrain_dataloader import *
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from utils import *

device = 'cuda' if torch.cuda.is_available() else 'cpu'
import torch.nn.functional as F
torch.backends.cudnn.enabled = True
import copy
import wandb
import geoopt

[NbConvertApp] Converting notebook BYOL.ipynb to script
[NbConvertApp] Writing 8063 bytes to BYOL.py


  stdout_func(


In [2]:
#returns necessary default arguments, training loader and networks
args = return_default_args()
args.batch_size = 256
train_loader = prepare_cifar_train_loader(args)
online = BYOL.euclidean_BYOL_module().to(device).to(memory_format=torch.channels_last)
target = copy.deepcopy(online).to(device).to(memory_format=torch.channels_last)

#gives us the mapping from fine labels to the coarse labels
fine_to_coarse = fine_to_coarse_dict()

Files already downloaded and verified


In [3]:
params = [
            {"name": "backbone", "params": online.network.parameters()},
            {
                "name": "classifier",
                "params": online.classifier.parameters(),
                "lr": args.classifier_lr,
                "weight_decay": args.classifier_weight_decay,
            },
            {
                "name": "coarse_classifier",
                "params": online.coarse_classifier.parameters(),
                "lr": args.classifier_lr,
                "weight_decay": args.classifier_weight_decay,
            },
            {
                "name": "momentum_classifier",
                "params": target.classifier.parameters(),
                "lr": args.classifier_lr,
                "weight_decay": args.classifier_weight_decay,
            },
            {
                "name": "coarse_momentum_classifier",
                "params": target.coarse_classifier.parameters(),
                "lr": args.classifier_lr,
                "weight_decay": args.classifier_weight_decay,
            },
            {'params': online.projector.parameters()},
            {'params':online.predictor.parameters(),}
        ]

#in use if we are using a hyperbolic classifier
if online.h_classifier:
    print('Using Hyperbolic Classifier')
    params.append({
                "name": "hyperbolic_classifier",
                "params": online.h_classifier.parameters(),
                "lr": args.classifier_lr,
                "weight_decay": args.classifier_weight_decay,
            })
    params.append({
                "name": "hyperbolic_coarse_classifier",
                "params": online.h_coarse_classifier.parameters(),
                "lr": args.classifier_lr,
                "weight_decay": args.classifier_weight_decay,
            })
    params.append({
                "name": "momentum_hyperbolic_classifier",
                "params": target.h_classifier.parameters(),
                "lr": args.classifier_lr,
                "weight_decay": args.classifier_weight_decay,
            })
    params.append({
                "name": "momentum_hyperbolic_coarse_classifier",
                "params": target.h_coarse_classifier.parameters(),
                "lr": args.classifier_lr,
                "weight_decay": args.classifier_weight_decay,
            })
    params.append({'params': online.h_representation.parameters()})
    params.append({'params': online.h_representation_coarse.parameters()})
    params.append({'params': online.hyperbolic_projector.parameters()})

args.lr = .001
args.wd = 1.5e-5

#opt = geoopt.optim.RiemannianAdam(params, lr=args.lr, weight_decay=args.wd, stabilize=10)
opt = torch.optim.AdamW(params, lr = args.lr, weight_decay = args.wd)
schedule = LinearWarmupCosineAnnealingLR(
                    opt,
                    warmup_epochs= 10 * 195,
                    max_epochs= 195 * 200,
                    warmup_start_lr=3e-05,
                    eta_min=0)

In [4]:
def training_step(data, online, target, fine_to_coarse = fine_to_coarse):
    step_metrics = {}
    
    labels2 = []
    for i in data[2]:
        labels2.append(fine_to_coarse[i.item()])
    labels2 = torch.Tensor(labels2)
    labels2 = labels2.long().to(device)
    
    
    
    online1 = online(data[1][0])
    online2 = online(data[1][1])
    #target1 = target.momentum_forward(data[1][0])
    #target2 = target.momentum_forward(data[1][1])
    
    target1 = online1
    target2 = online2
    
    #byol_loss
    step_metrics["byol_loss"] = byol_loss_func(online1['p'], target2['z'])
    step_metrics["byol_loss"] += byol_loss_func(online2['p'], target1['z'])
    
    #cross entropy loss
    
    step_metrics["online_cross_entropy_loss"] = F.cross_entropy(online1['logits'], data[2], ignore_index=-1)
    step_metrics["momentum_cross_entropy_loss"] = F.cross_entropy(target1['logits'], data[2], ignore_index=-1)
    step_metrics["coarse_online_cross_entropy_loss"] = F.cross_entropy(online1['coarse_logits'], labels2, ignore_index=-1)
    step_metrics["coarse_momentum_cross_entropy_loss"] = F.cross_entropy(target1['coarse_logits'], labels2, ignore_index=-1)
    

    if online.h_classifier:
        step_metrics["online_hyperbolic_cross_entropy_loss"] = F.cross_entropy(online1['h_logits'], data[2], ignore_index=-1)
        step_metrics["momentum_hyperbolic_cross_entropy_loss"] = F.cross_entropy(target1['h_logits'], data[2], ignore_index=-1)
        step_metrics["coarse_online_hyperbolic_cross_entropy_loss"] = F.cross_entropy(online1['h_coarse_logits'], labels2, ignore_index=-1)
        step_metrics["coarse_momentum_hyperbolic_cross_entropy_loss"] = F.cross_entropy(target1['h_coarse_logits'], labels2, ignore_index=-1)
    
    #accuracy of predictions
    _, predicted = torch.max(online1['logits'], 1)
    step_metrics["online_acc1"] = (predicted == data[2]).sum()
    _, predicted = torch.max(target1['logits'], 1)
    step_metrics["target_acc1"] = (predicted == data[2]).sum()
    
    _, pred = online1['logits'].topk(5)
    data2 = data[2].unsqueeze(1).expand_as(pred)
    step_metrics["online_acc5"] = (data2 == pred).any(dim = 1).sum()
    _, pred = target1['logits'].topk(5)
    step_metrics["target_acc5"] = (data2 == pred).any(dim = 1).sum()
    
    #accuracy of predictions
    _, predicted = torch.max(online1['coarse_logits'], 1)
    step_metrics["coarse_online_acc1"] = (predicted == labels2).sum()
    _, predicted = torch.max(target1['coarse_logits'], 1)
    step_metrics["coarse_target_acc1"] = (predicted == labels2).sum()
    
    _, pred = online1['coarse_logits'].topk(5)
    labels2_ = labels2.unsqueeze(1).expand_as(pred)
    step_metrics["coarse_online_acc5"] = (labels2_ == pred).any(dim = 1).sum()
    _, pred = target1['coarse_logits'].topk(5)
    step_metrics["coarse_target_acc5"] = (labels2_ == pred).any(dim = 1).sum()
    
    if online.h_classifier:
        #accuracy of predictions
        _, predicted = torch.max(online1['h_logits'], 1)
        step_metrics["online_hyperbolic_acc1"] = (predicted == data[2]).sum()
        _, predicted = torch.max(target1['h_logits'], 1)
        step_metrics["target_hyperbolic_acc1"] = (predicted == data[2]).sum()

        _, pred = online1['h_logits'].topk(5)
        data[2] = data[2].unsqueeze(1).expand_as(pred)
        step_metrics["online_hyperbolic_acc5"] = (data[2] == pred).any(dim = 1).sum()
        _, pred = target1['h_logits'].topk(5)
        step_metrics["target_hyperbolic_acc5"] = (data[2] == pred).any(dim = 1).sum()

        #accuracy of predictions
        _, predicted = torch.max(online1['h_coarse_logits'], 1)
        step_metrics["coarse_online_hyperbolic_acc1"] = (predicted == labels2).sum()
        _, predicted = torch.max(target1['h_coarse_logits'], 1)
        step_metrics["coarse_target_hyperbolic_acc1"] = (predicted == labels2).sum()

        _, pred = online1['h_coarse_logits'].topk(5)
        labels2 = labels2.unsqueeze(1).expand_as(pred)
        step_metrics["coarse_online_hyperbolic_acc5"] = (labels2 == pred).any(dim = 1).sum()
        _, pred = target1['h_coarse_logits'].topk(5)
        step_metrics["coarse_target_hyperbolic_acc5"] = (labels2 == pred).any(dim = 1).sum()
                   
        
    
    
    #metrics to track
    with torch.no_grad():
        online1p_softmax = F.softmax(online1['p'], dim = 1)
        online2p_softmax = F.softmax(online2['p'], dim = 1)
        online1rep_softmax = F.softmax(online1['Representation'], dim = 1)
        online2rep_softmax = F.softmax(online2['Representation'], dim = 1)
        online1z_softmax = F.softmax(online1['z'], dim = 1)
        online2z_softmax = F.softmax(online2['z'], dim = 1)
        
        target1z_softmax = F.softmax(target1['z'], dim = 1)
        target2z_softmax = F.softmax(target1['z'], dim = 1)
        target1rep_softmax = F.softmax(target1['Representation'], dim = 1)
        target2rep_softmax = F.softmax(target2['Representation'], dim = 1)
        
        
        cross_entropy = F.cross_entropy(online1p_softmax, target2z_softmax, ignore_index=-1) + F.cross_entropy(online2p_softmax, target1z_softmax, ignore_index=-1)
        l1_dist = F.l1_loss(online1['p'], target2['z']) + F.l1_loss(online2['p'], target1['z'])
        l2_dist = F.mse_loss(online1['p'], target2['z']) + F.mse_loss(online2['p'], target1['z'])
        smooth_l1 = F.smooth_l1_loss(online1['p'], target2['z']) + F.smooth_l1_loss(online2['p'], target1['z'])
        kl_div = F.kl_div(online1p_softmax, target2z_softmax) + F.kl_div(online2p_softmax, target1z_softmax)
        
        representation_l1 = F.l1_loss(online1['Representation'], target2['Representation']) + F.l1_loss(online2['Representation'], target1['Representation'])
        representation_l2 = F.mse_loss(online1['Representation'], target2['Representation']) + F.mse_loss(online2['Representation'], target1['Representation'])
        representation_cross_entropy = F.cross_entropy(online1rep_softmax, target2rep_softmax, ignore_index=-1) + F.cross_entropy(online2rep_softmax, target1rep_softmax, ignore_index=-1)
        representation_kl = F.kl_div(online1rep_softmax, target2rep_softmax) + F.kl_div(online2rep_softmax, target1rep_softmax)
        representation_cos_sim = F.cosine_similarity(online1['Representation'], target2['Representation']) + F.cosine_similarity(online2['Representation'], target1['Representation'])

        projection_l1 = F.l1_loss(online1['z'], target2['z']) + F.l1_loss(online2['z'], target1['z'])
        projection_l2 = F.mse_loss(online1['z'], target2['z']) + F.mse_loss(online2['z'], target1['z'])
        projection_cross_entropy = F.cross_entropy(online1z_softmax, target2z_softmax, ignore_index=-1) + F.cross_entropy(online2z_softmax, target1z_softmax, ignore_index=-1)
        projection_kl = F.kl_div(online1z_softmax, target2z_softmax) + F.kl_div(online2z_softmax, target1z_softmax)
        projection_cos_sim = F.cosine_similarity(online1['z'], target2['z']) + F.cosine_similarity(online2['z'], target1['z'])

        momentum_projection_cos_sim = F.cosine_similarity(target1['z'], target2['z'])
        momentum_representation_cos_sim = F.cosine_similarity(target1['Representation'], target2['Representation'])
        momentum_representation_l2 =  F.mse_loss(target1['Representation'], target2['Representation'])
        momentum_projection_l2 =  F.mse_loss(target1['z'], target2['z'])

        online_projection_cos_sim = F.cosine_similarity(online1['z'], online2['z'])
        online_representation_cos_sim = F.cosine_similarity(online1['Representation'], online2['Representation'])
        online_representation_l2 =  F.mse_loss(online1['Representation'], online2['Representation'])
        online_projection_l2 =  F.mse_loss(online1['z'], online2['z'])
        
        online_representation_std = torch.mean(torch.std(online1['Representation']))
        online_projection_std = torch.mean(torch.std(online1['z']))
        online_prediction_std = torch.mean(torch.std(online1['p']))
        target_representation_std = torch.mean(torch.std(target1['Representation']))
        target_projection_std = torch.mean(torch.std(target1['z']))

        step_metrics.update({
            "train_feats_cross_entropy": cross_entropy,
            "train_feats_l1_dist": l1_dist,
            "train_feats_l2_dist": l2_dist,
            "train_feats_smooth_l1": smooth_l1,
            "train_feats_kl_div": kl_div,
            "representation_l1": representation_l1,
            "representation_l2": representation_l1,
            "representation_cross_entropy": representation_cross_entropy,
            "representation_kl": representation_kl,
            "representation_cos_sim": representation_cos_sim.mean(),
            "projection_l1": projection_l1,
            "projection_l2": projection_l2,
            "projection_cross_entropy": projection_cross_entropy,
            "projection_kl": projection_kl,
            "projection_cos_sim": projection_cos_sim.mean(),
            "momentum_projection_cos_sim": momentum_projection_cos_sim.mean(),
            "momentum_representation_cos_sim": momentum_representation_cos_sim.mean(),
            "momentum_representation_l2": momentum_representation_l2,
            "momentum_representation_l2": momentum_projection_l2,
            "online_projection_cos_sim": online_projection_cos_sim.mean(),
            "online_representation_cos_sim": online_representation_cos_sim.mean(),
            "online_representation_l2": online_representation_l2,
            "online_projection_l2": online_projection_l2,
            "online_representation_std": online_representation_std,
            "online_projection_std": online_projection_std,
            "online_prediction_std": online_prediction_std,
            "target_representation_std": target_representation_std,
            "target_projection_std": target_projection_std,
            })
        
    loss = step_metrics["byol_loss"] + step_metrics["online_cross_entropy_loss"] + \
            step_metrics["momentum_cross_entropy_loss"] + step_metrics["coarse_online_cross_entropy_loss"] \
            + step_metrics["coarse_momentum_cross_entropy_loss"]
    
    if online.h_classifier:
        loss += step_metrics["online_hyperbolic_cross_entropy_loss"] + step_metrics["momentum_hyperbolic_cross_entropy_loss"] + \
                step_metrics["coarse_online_hyperbolic_cross_entropy_loss"] + step_metrics["coarse_momentum_hyperbolic_cross_entropy_loss"]
        
    return loss, step_metrics




In [5]:
log = True
name = 'BYOL Hyperbolic No Linear Before H-Classifier'
if log:
    wandb.init(config = args.__dict__, name = name, project = 'hyperbolic_byol')

[34m[1mwandb[0m: Currently logged in as: [33mdcaustin33[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
step = 0
scaler = torch.cuda.amp.GradScaler()
if online.h_classifier:
    print('Running Hyperbolic version of BYOL')
    
for e in range(200):
    metrics = {
        "byol_loss": 0,
        'online_cross_entropy_loss': 0,
        'momentum_cross_entropy_loss': 0,
        'coarse_online_cross_entropy_loss': 0,
        'coarse_momentum_cross_entropy_loss': 0,
        'online_acc1': 0,
        'target_acc1': 0,
        'online_acc5': 0,
        'target_acc5': 0,
        'coarse_online_acc1': 0,
        'coarse_target_acc1': 0,
        'coarse_online_acc5': 0,
        'coarse_target_acc5': 0,
        "train_feats_cross_entropy": 0,
        "train_feats_l1_dist": 0,
        "train_feats_l2_dist": 0,
        "train_feats_smooth_l1": 0,
        "train_feats_kl_div": 0,
        "representation_l1": 0,
        "representation_l2": 0,
        "representation_cross_entropy": 0,
        "representation_kl": 0,
        "representation_cos_sim": 0,
        "projection_l1": 0,
        "projection_l2": 0,
        "projection_cross_entropy": 0,
        "projection_kl": 0,
        "projection_cos_sim": 0,
        "momentum_projection_cos_sim": 0,
        "momentum_representation_cos_sim": 0,
        "momentum_representation_l2": 0,
        "momentum_representation_l2": 0,
        "online_projection_cos_sim": 0,
        "online_representation_cos_sim": 0,
        "online_representation_l2": 0,
        "online_projection_l2": 0,
        "online_representation_std": 0,
        "online_projection_std": 0,
        "online_prediction_std": 0,
        "target_representation_std": 0,
        "target_projection_std": 0,
        }
    
    if online.h_classifier:
        metrics['online_hyperbolic_cross_entropy_loss'] = 0
        metrics['momentum_hyperbolic_cross_entropy_loss'] = 0
        metrics['coarse_online_hyperbolic_cross_entropy_loss'] = 0
        metrics['coarse_momentum_hyperbolic_cross_entropy_loss'] = 0
        
        metrics['online_hyperbolic_acc1'] = 0
        metrics['target_hyperbolic_acc1'] = 0
        metrics['online_hyperbolic_acc5'] = 0
        metrics['target_hyperbolic_acc5'] = 0
        metrics['coarse_online_hyperbolic_acc1'] = 0
        metrics['coarse_target_hyperbolic_acc1'] = 0
        metrics['coarse_online_hyperbolic_acc5'] = 0
        metrics['coarse_target_hyperbolic_acc5'] = 0
        
    
    
    total_loss = 0
    now = time.time()
    for i_, data in enumerate(train_loader):
        
        with torch.autocast(device):
            step += 1
            data[0] = data[0].to(device)
            data[1][0] = data[1][0].to(device).to(memory_format=torch.channels_last)
            data[1][1] = data[1][1].to(device).to(memory_format=torch.channels_last)
            data[2] = data[2].to(device)
            
            
            loss, step_metrics = training_step(data, online, target)
            opt.zero_grad()
            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(online.parameters(), 1.0)
            scaler.step(opt)
            schedule.step()
            scaler.update()
            update_target_params(online.network.parameters(), target.network.parameters(), .99)
            update_target_params(online.projector.parameters(), target.projector.parameters(), .99)
                
        total_loss += loss
        for key in step_metrics:
            metrics[key] += step_metrics[key]
    
    print(e, round(time.time() - now, 2), total_loss.item() / (i_+1), metrics['byol_loss'].item() / (i_+1), 
          metrics['online_acc1'].item() / 50000, metrics['target_acc1'].item() / 50000, metrics['coarse_online_acc1'].item() / 50000)
    if log:
        i_ += 1
        for key in metrics:
            if key[-4:-1] == 'acc':
                metrics[key] = metrics[key] / 50000
            else:
                metrics[key] = metrics[key]/i_
        wandb.log(metrics)
    checkpoint = {
    'online': online.state_dict(),
    'target': target.state_dict(),
    'epoch': e,
    'optimizer': opt.state_dict()
    }
    torch.save(checkpoint, 'checkpoint.pt')

0 43.83 15.93349859775641 1.376688482822516 0.03812 0.03812 0.11544
1 43.08 15.673069411057693 1.2910663311298076 0.04546 0.04546 0.13044
2 43.39 15.549864783653845 1.2780930739182692 0.05118 0.05118 0.14302
3 43.32 15.51539087540064 1.268421427408854 0.0579 0.0579 0.15096
4 43.41 15.468130258413462 1.227275124574319 0.06006 0.06006 0.1557
5 43.4 15.378087439903846 1.1650713798327323 0.06488 0.06488 0.1643
6 43.52 15.24981219951923 1.1086284930889423 0.07216 0.07216 0.1681
7 43.12 15.364014923878205 1.0953783084184696 0.07 0.07 0.16836


In [None]:
#.01 LR
'''
0 54.98 16.301710236378206 1.3990733611278046 0.02588 0.02302 0.09818
1 53.84 15.126978165064102 0.9789630596454327 0.05156 0.05114 0.14322
2 54.21 14.806014623397436 0.931276370317508 0.0646 0.06716 0.1626
3 54.14 14.602803235176282 0.8721234443860176 0.07308 0.0747 0.17544
4 54.28 14.423293519631411 0.8290852864583333 0.08204 0.08572 0.18862
5 54.18 14.334297375801283 0.8047475179036458 0.09012 0.09398 0.19864
6 54.06 14.161805138221155 0.792535165640024 0.09704 0.10206 0.2092
7 54.09 14.14975335536859 0.7905127892127404 0.1046 0.10946 0.21614
8 54.09 13.999675731169871 0.7640458327073317 0.1108 0.11748 0.22088
9 53.94 13.861602313701923 0.7397880358573717 0.1194 0.12544 0.23064
10 54.06 13.705500050080127 0.7379991580278445 0.12708 0.13374 0.24136
11 53.99 13.454622395833333 0.7214473626552484 0.1366 0.14262 0.2535
12 54.06 13.020966045673077 0.7090653639573318 0.15088 0.1568 0.26708
13 53.97 12.838525390625 0.7102927183493589 0.1607 0.16826 0.27686'''

#lr = .1 got nana in the loss

In [None]:
torch.save(online.state_dict(), name)
torch.save(target.state_dict(), 'target' + name)

In [8]:
online(data[1][0])

{'Representation': tensor([[1.0902, 0.7197, 0.7068,  ..., 0.6889, 0.6700, 0.7505],
         [0.6945, 0.7999, 0.8918,  ..., 1.0819, 0.9804, 0.8990]],
        grad_fn=<ReshapeAliasBackward0>),
 'logits': tensor([[-1.0778e-01,  6.7949e-01,  2.9006e-01, -5.4419e-01,  8.8811e-01,
          -6.3720e-01, -3.0231e-01,  4.0207e-01,  6.0091e-01, -1.1101e+00,
          -2.7813e-01,  4.0618e-01,  5.7357e-01,  8.0544e-01,  3.3681e-01,
          -3.4905e-01,  5.9667e-01, -7.6058e-01, -2.2123e-01, -1.1967e+00,
          -2.1913e-01,  3.6329e-01, -4.0279e-01, -3.3254e-01, -3.8375e-02,
          -1.3133e-01,  5.0848e-01,  1.9677e-01,  2.6188e-01, -5.9601e-01,
           4.1806e-01,  1.0126e-01,  4.5833e-02, -1.4168e-01,  1.9744e-01,
           4.5370e-01,  4.6671e-01,  8.5084e-01,  2.1330e-01, -9.1768e-01,
          -1.0758e+00,  7.0109e-01, -1.5537e-01,  6.2724e-02,  1.3873e+00,
          -7.1104e-02, -7.2754e-02,  1.0333e+00, -4.2682e-01, -1.0542e-01,
           1.2579e-01, -1.3975e-01,  4.2783e-01, 