In [27]:
!jupyter nbconvert --to script BYOL.ipynb
from BYOL import *
import torch
import time
from pretrain_dataloader import *
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from utils import *

device = 'cuda' if torch.cuda.is_available() else 'cpu'
import torch.nn.functional as F
torch.backends.cudnn.enabled = True
from lars import LARSWrapper
import copy
import wandb
import geoopt

[NbConvertApp] Converting notebook BYOL.ipynb to script
[NbConvertApp] Writing 8063 bytes to BYOL.py


In [28]:
#build the networks

class euclidean_network(nn.Module):
    
    def __init__(self, input_size = (3, 32, 32), classes = 100, coarse_classes = 20):
        super().__init__()
        self.network = ResNet18()
        example = torch.ones(input_size)
        self.feature_dim = (self.network(example.unsqueeze(0))).shape[1]
        self.classifier = nn.Linear(self.feature_dim, classes)
        self.coarse_classifier = nn.Linear(self.feature_dim, coarse_classes)
        self.h_classifier = None
    
    def forward(self, x):
        representation = self.network(x)
        logits = self.classifier(representation.detach())
        coarse_logits = self.coarse_classifier(representation.detach())
        
        out = {'Representation': representation,
              "logits": logits,
              "coarse_logits": coarse_logits} 
        return out

In [29]:
#returns necessary default arguments, training loader and networks
args = return_default_args()
args.batch_size = 2
args.lr = .01
args.classifier_lr = .001

train_loader = prepare_cifar_train_loader(args)
network = euclidean_network()



#gives us the mapping from fine labels to the coarse labels
fine_to_coarse = fine_to_coarse_dict()

Files already downloaded and verified


In [30]:
params = [  {"name": "backbone", "params": network.network.parameters()},
            {
                "name": "classifier",
                "params": network.classifier.parameters(),
                "lr": args.classifier_lr,
                "weight_decay": args.classifier_weight_decay,
            },
            {
                "name": "coarse_classifier",
                "params": network.coarse_classifier.parameters(),
                "lr": args.classifier_lr,
                "weight_decay": args.classifier_weight_decay,
            }]


#opt = geoopt.optim.RiemannianAdam(params, lr=args.lr, weight_decay=args.wd, stabilize=10)
opt = torch.optim.AdamW(params, lr = args.lr, weight_decay = args.wd)
schedule = LinearWarmupCosineAnnealingLR(
                    opt,
                    warmup_epochs= 10 * 195,
                    max_epochs= 195 * 200,
                    warmup_start_lr=3e-05,
                    eta_min=0)

In [31]:
def training_step(data, network, fine_to_coarse = fine_to_coarse):
    step_metrics = {}
    
    labels2 = []
    for i in data[2]:
        labels2.append(fine_to_coarse[i.item()])
    labels2 = torch.Tensor(labels2)
    labels2 = labels2.long().to(device)
    
    
    
    output = network(data[1][0])
    
    #cross entropy loss
    
    step_metrics["online_cross_entropy_loss"] = F.cross_entropy(output['logits'], data[2], ignore_index=-1)
    step_metrics["coarse_online_cross_entropy_loss"] = F.cross_entropy(output['coarse_logits'], labels2, ignore_index=-1)
    

    if network.h_classifier:
        step_metrics["online_hyperbolic_cross_entropy_loss"] = F.cross_entropy(output['h_logits'], data[2], ignore_index=-1)
        step_metrics["coarse_online_hyperbolic_cross_entropy_loss"] = F.cross_entropy(output['h_coarse_logits'], labels2, ignore_index=-1)
    
    #accuracy of predictions
    _, predicted = torch.max(output['logits'], 1)
    step_metrics["online_acc1"] = (predicted == data[2]).sum()
    
    _, pred = output['logits'].topk(5)
    data2 = data[2].unsqueeze(1).expand_as(pred)
    step_metrics["online_acc5"] = (data2 == pred).any(dim = 1).sum()
    
    #accuracy of predictions
    _, predicted = torch.max(output['coarse_logits'], 1)
    step_metrics["coarse_online_acc1"] = (predicted == labels2).sum()
    
    _, pred = output['coarse_logits'].topk(5)
    labels2_ = labels2.unsqueeze(1).expand_as(pred)
    step_metrics["coarse_online_acc5"] = (labels2_ == pred).any(dim = 1).sum()
    
    if network.h_classifier:
        #accuracy of predictions
        _, predicted = torch.max(output['h_logits'], 1)
        step_metrics["online_hyperbolic_acc1"] = (predicted == data[2]).sum()

        _, pred = output['h_logits'].topk(5)
        data[2] = data[2].unsqueeze(1).expand_as(pred)
        step_metrics["online_hyperbolic_acc5"] = (data[2] == pred).any(dim = 1).sum()

        #accuracy of predictions
        _, predicted = torch.max(output['h_coarse_logits'], 1)
        step_metrics["coarse_online_hyperbolic_acc1"] = (predicted == labels2).sum()

        _, pred = output['h_coarse_logits'].topk(5)
        labels2 = labels2.unsqueeze(1).expand_as(pred)
        step_metrics["coarse_online_hyperbolic_acc5"] = (labels2 == pred).any(dim = 1).sum()
                   
        
    loss = step_metrics["online_cross_entropy_loss"] + step_metrics["coarse_online_cross_entropy_loss"]
    
    if network.h_classifier:
        loss += step_metrics["online_hyperbolic_cross_entropy_loss"] + step_metrics["coarse_online_hyperbolic_cross_entropy_loss"]
        
    return loss, step_metrics




In [32]:
"""log = True
name = 'Supervised Euclidean CIFAR-10'
if log:
    wandb.init(config = args.__dict__, name = name, project = 'hyperbolic_byol')"""

"log = True\nname = 'Supervised Euclidean CIFAR-10'\nif log:\n    wandb.init(config = args.__dict__, name = name, project = 'hyperbolic_byol')"

In [36]:
step = 0
scaler = torch.cuda.amp.GradScaler()
if network.h_classifier:
    print('Running Hyperbolic version of BYOL')
    
for e in range(200):
    metrics = {
        'online_cross_entropy_loss': 0,
        'coarse_online_cross_entropy_loss': 0,
        'online_acc1': 0,
        'online_acc5': 0,
        'coarse_online_acc1': 0,
        'coarse_online_acc5': 0}
    
    if network.h_classifier:
        metrics['online_hyperbolic_cross_entropy_loss'] = 0
        metrics['coarse_online_hyperbolic_cross_entropy_loss'] = 0
        
        metrics['online_hyperbolic_acc1'] = 0
        metrics['online_hyperbolic_acc5'] = 0
        metrics['coarse_online_hyperbolic_acc1'] = 0
        metrics['coarse_online_hyperbolic_acc5'] = 0
        
    
    
    total_loss = 0
    now = time.time()
    for i_, data in enumerate(train_loader):
        
        with torch.autocast(device):
            step += 1
            data[0] = data[0].to(device)
            data[1][0] = data[1][0].to(device).to(memory_format=torch.channels_last)
            data[1][1] = data[1][1].to(device).to(memory_format=torch.channels_last)
            data[2] = data[2].to(device)
            
            
            loss, step_metrics = training_step(data, network)
            opt.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(opt)
            schedule.step()
            scaler.update()
                
        total_loss += loss
        for key in step_metrics:
            metrics[key] += step_metrics[key]
        
    
        print(e, round(time.time() - now, 2), total_loss.item() / (i_+1), metrics['online_acc1'].item() / 50000, 
          metrics['coarse_online_acc1'].item() / 50000)
    if log:
        i_ += 1
        for key in metrics:
            if key[-4:-1] == 'acc':
                metrics[key] = metrics[key] / 50000
            else:
                metrics[key] = metrics[key]/i_
        wandb.log(metrics)
    checkpoint = {
    'online': network.state_dict(),
    'epoch': e,
    'optimizer': opt.state_dict()
    }
    torch.save(checkpoint, 'checkpoint.pt')
    
    

0 1.72 7.84375 0.0 0.0
0 3.27 8.125 0.0 0.0
0 4.85 7.833333333333333 0.0 0.0
0 6.65 7.75 0.0 2e-05
0 8.55 7.55 2e-05 2e-05
0 10.15 7.5 2e-05 2e-05
0 12.0 7.642857142857143 2e-05 2e-05
0 13.65 7.59375 2e-05 4e-05
0 15.23 7.611111111111111 2e-05 4e-05
0 16.87 7.7 2e-05 4e-05
0 18.47 7.7727272727272725 2e-05 4e-05
0 20.11 7.708333333333333 2e-05 4e-05
0 21.65 7.6923076923076925 2e-05 6e-05
0 23.48 7.642857142857143 2e-05 8e-05
0 25.1 7.666666666666667 2e-05 8e-05
0 26.66 7.6875 2e-05 8e-05
0 28.34 7.647058823529412 2e-05 8e-05
0 29.92 7.666666666666667 2e-05 8e-05
0 31.52 7.684210526315789 2e-05 8e-05
0 33.19 7.65 2e-05 0.0001
0 34.88 7.666666666666667 2e-05 0.0001
0 36.46 7.636363636363637 2e-05 0.0001
0 38.03 7.695652173913044 2e-05 0.0001
0 39.6 7.75 2e-05 0.0001
0 41.27 7.76 2e-05 0.0001
0 43.14 7.769230769230769 2e-05 0.0001
0 44.74 7.7407407407407405 2e-05 0.0001
0 46.31 7.714285714285714 2e-05 0.00012
0 47.9 7.724137931034483 2e-05 0.00012
0 49.48 7.7 4e-05 0.00012
0 51.04 7.677419

Traceback (most recent call last):
Traceback (most recent call last):
  File "/Users/Derek/opt/anaconda3/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/Users/Derek/opt/anaconda3/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/Users/Derek/opt/anaconda3/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/Users/Derek/opt/anaconda3/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
  File "/Users/Derek/opt/anaconda3/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/Users/Derek/opt/anaconda3/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/Users/Derek/opt/anaconda3/lib/python3.7/multiprocessing/connection.py", line 40

KeyboardInterrupt: 