In [None]:
from resnet import *
from train_util import *
from FP_layers import *
from train_classes import *
# from train_class import *

import torch
import numpy as np
import matplotlib.pyplot as plt

import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim


import time
import os
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

# Run on the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
# print(device, 'torch.distributed.is_available():', torch.distributed.is_available())
print(device)
# torch.distributed.init_process_group()

In [None]:
def nt_xent_loss(x, temperature=0.1):
    """NT-Xent (normalised temperature-scaled cross entropy) loss.

    Args:
        x: 2-D tensor of shape [2 * n_pairs, dim]. Rows are expected to be
            interleaved so that rows 2k and 2k+1 are the two views of the
            same sample (each is the other's positive).
        temperature: divisor applied to the cosine similarities before the
            softmax.

    Returns:
        Scalar tensor: the mean cross entropy over all rows, computed on the
        same device as ``x`` (no dependency on a notebook-global ``device``).

    Raises:
        ValueError: if ``x`` is not 2-D or has an odd number of rows, since
            rows could then not be grouped into positive pairs.
    """
    if x.dim() != 2 or x.size(0) % 2 != 0:
        raise ValueError(
            "nt_xent_loss expects a 2-D tensor with an even number of rows, got shape %s"
            % (tuple(x.shape),))

    n_rows = x.size(0)

    # Pairwise cosine similarity as dot products of unit vectors. Broadcasting
    # x against itself would materialise an [n_rows, n_rows, dim] intermediate.
    x_unit = F.normalize(x, p=2, dim=1, eps=1e-8)
    similarity = torch.matmul(x_unit, x_unit.T)

    # A row must never pick itself: -inf removes the diagonal from the softmax.
    self_mask = torch.eye(n_rows, dtype=torch.bool, device=x.device)
    similarity = similarity.masked_fill(self_mask, float("-inf"))

    # The positive of row 2k is row 2k+1 and vice versa, so the "class" each
    # row has to predict is the index of its neighbour.
    target = torch.arange(n_rows, device=x.device)
    target[0::2] += 1
    target[1::2] -= 1

    # Standard cross entropy over each row of temperature-scaled similarities.
    return F.cross_entropy(similarity / temperature, target, reduction="mean")

def add_contrastive_loss(hidden,
                         hidden_norm=True,
                         temperature=1.0,
                         weights=1.0):
    """Compute the two-view contrastive loss for a batch of embeddings.

    Args:
        hidden: hidden vector (`Tensor`) of shape (2 * bsz, dim). The first
            bsz rows are view A and the last bsz rows are view B; row i of A
            and row i of B form a positive pair.
        hidden_norm: whether or not to L2-normalise the hidden vectors first.
        temperature: a `floating` number for temperature scaling.
        weights: a weighting number, or a vector of length bsz, multiplied
            into the per-example losses before they are averaged.

    Returns:
        A loss scalar.
        The logits for contrastive prediction task (view A against view B).
        The one-hot labels for contrastive prediction task, shape (bsz, 2 * bsz).

    Raises:
        ValueError: if `hidden` is not 2-D, is empty, or has an odd number of
            rows (it could not be split into two equally sized views).
    """
    if hidden.dim() != 2 or hidden.shape[0] < 2 or hidden.shape[0] % 2 != 0:
        raise ValueError(
            "add_contrastive_loss expects shape (2 * bsz, dim) with bsz >= 1, got %s"
            % (tuple(hidden.shape),))

    # Get (normalized) hidden1 and hidden2.
    if hidden_norm:
        hidden = F.normalize(hidden, p=2, dim=1)
    batch_size = hidden.shape[0] // 2
    hidden1, hidden2 = torch.split(hidden, batch_size, dim=0)

    # Everything below is created on hidden's device so the function works
    # for CPU and GPU inputs alike.
    positive_index = torch.arange(batch_size, device=hidden.device)
    labels = F.one_hot(positive_index, batch_size * 2)
    self_mask = torch.eye(batch_size, dtype=torch.bool, device=hidden.device)

    # Same-view similarities: a sample must not act as its own negative, so
    # the diagonal is removed from the softmax with -inf.
    logits_aa = torch.matmul(hidden1, hidden1.T) / temperature
    logits_aa = logits_aa.masked_fill(self_mask, float("-inf"))
    logits_bb = torch.matmul(hidden2, hidden2.T) / temperature
    logits_bb = logits_bb.masked_fill(self_mask, float("-inf"))
    # Cross-view similarities: the diagonal holds the positive pairs.
    logits_ab = torch.matmul(hidden1, hidden2.T) / temperature
    logits_ba = torch.matmul(hidden2, hidden1.T) / temperature

    # F.cross_entropy takes (input logits, target). Class-index targets are
    # equivalent to the one-hot `labels` and keep -inf logits from turning
    # into NaN through 0 * -inf.
    per_example_a = F.cross_entropy(
        torch.cat([logits_ab, logits_aa], 1), positive_index, reduction="none")
    per_example_b = F.cross_entropy(
        torch.cat([logits_ba, logits_bb], 1), positive_index, reduction="none")

    weights = torch.as_tensor(
        weights, dtype=per_example_a.dtype, device=per_example_a.device)
    loss_a = (per_example_a * weights).mean()
    loss_b = (per_example_b * weights).mean()
    loss = loss_a + loss_b

    return loss, logits_ab, labels

# Sanity check: evaluate the three contrastive-loss implementations on the
# same synthetic batch, at the same temperature, and print the results.
bs = 512
t = 0.1
size = 64
crit = ContrastiveLoss(batch_size=bs, temperature=t)

# Row i of each view is i * (1 + sin(.)); both views are identical. The sine
# profile does not depend on i, so it is computed once.
# NOTE(review): row 0 is therefore the all-zero vector - confirm that is intended.
base = 1 + torch.Tensor(np.sin(np.linspace(0, size, num=size)))
t1 = torch.arange(bs, dtype=base.dtype)[:, None] * base[None, :]
t2 = t1.clone()

# nt_xent_loss pairs adjacent rows (2k, 2k+1), so the two views are interleaved.
t3 = torch.empty(bs * 2, size)
t3[::2, :] = t1
t3[1::2, :] = t2

# add_contrastive_loss splits its input into first half / second half, so it
# gets the views stacked rather than interleaved.
print(crit(t1, t2), nt_xent_loss(t3, temperature=t), t1.shape,
      add_contrastive_loss(torch.cat([t1, t2], dim=0), temperature=t)[0])
# print(world_size)
# train_w_DDP(epochs=100, batch_size=batch_size, lr=0.3*batch_size/256, reg=1e-6, world_size=world_size, log_every_n=50)
# rank, world_size, epochs, batch_size, lr, reg, head, log_every_n=50

# rank, world_size, epochs, batch_size, lr, reg, head, log_every_n=50):
# net, epochs, batch_size, lr, reg, rank, world_size, log_every_n=50
# optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
# optimizer.zero_grad()
# loss_fn(model(input), target).backward()
# optimizer.step()

In [None]:
from sklearn.decomposition import PCA
# # model = nn.parallel.DistributedDataParallel(ResNetCIFAR(head_g=head))
# Visualise the embeddings of one test batch: run the batch through the
# linear-evaluation network, group the stored embeddings by class label,
# project each class with PCA and scatter the first two components.
load_bs = 8192
trainloader, testloader = Trainer_wo_DDP.cifar_dataloader_wo_ddp(bs=load_bs, train_for_finetune=0)

bs = 4096
epoch = 1000
resnet_model_pth = "./saved_models/epoch_%d_bs_%d_lr_%g_reg_1e-06.pt" % \
                (epoch, bs, 0.3*bs/256)

lin_eval_net = LinearEvaluation(method='lin', which_device=device, resnet_model_pth=resnet_model_pth, Nbits=None, symmetric=False).to(device)

# Only the first test batch is needed.
batch, labels = next(iter(testloader))

# Inference only: no autograd graph is needed for a visualisation pass.
with torch.no_grad():
    _ = lin_eval_net(batch.to(device))

# The forward pass leaves the embeddings on `lin_eval_net.embedding`; copy them
# to host memory once instead of once per row.
embeddings = lin_eval_net.embedding.detach().cpu().numpy()
labels_np = torch.as_tensor(labels).cpu().numpy()

num_classes = 10
# Selecting rows by label (not by range(load_bs)) also works when the batch
# holds fewer than load_bs samples.
# NOTE(review): PCA is fitted separately per class, so each class is centred
# and rotated with its own basis and the plotted axes are not shared between
# classes - confirm this is the intended view.
out_dict = {c: PCA().fit_transform(embeddings[labels_np == c]) for c in range(num_classes)}

# print(concat_dict(out_dict).shape)

fig, ax = plt.subplots(1, 1)
print(out_dict[0][0].shape)
for c in range(num_classes):
    ax.scatter(out_dict[c][:, 0], out_dict[c][:, 1], s=1, alpha=0.4, label=str(c))
ax.set_xlabel("principal component 1")
ax.set_ylabel("principal component 2")
ax.set_title("Per-class PCA of test-batch embeddings")
ax.legend(title="class", markerscale=8)
plt.show()
# ax.plot(sorted(out_dict[1]))
# head = nn.Sequential(FP_Linear(64, 64, Nbits=None), nn.ReLU(True), FP_Linear(64, 64, Nbits=None))
# # head = nn.Sequential(nn.Linear(64, 64), nn.ReLU(True), nn.Linear(64, 64))
# # model = nn.DataParallel(ResNetCIFAR_mine(num_layers=50)).to(device)
# model = ResNetCIFAR(head_g=head, num_layers=50).to(device)
# # model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(nn.DataParallel(ResNetCIFAR(head_g=head, num_layers=50), device_ids=[0, 1]).to(device))
# # model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(nn.DataParallel(ResNetCIFAR(head_g=head, num_layers=50).to(device), device_ids=[0, 1]))
# # model = nn.DataParallel(ResNetCIFAR(head_g=head, num_layers=50)).to(device)
# # train_no_DDP(model, epochs=1000, batch_size=batch_size, lr=0.3*batch_size/256, reg=1e-6, log_every_n=20)
# train_nt_xet_class(model, epochs=1000, batch_size=batch_size, lr=0.3*batch_size/256, reg=1e-6, log_every_n=20)

# # # b = torch.rand((64, 3, 28, 28))
# # head = nn.Sequential(FP_Linear(64, 64, Nbits=None), nn.ReLU(True), FP_Linear(64, 64, Nbits=None))

# # y = model(b)

# # # xcs = F.cosine_similarity(y[None, :, :], y[:,None,:], dim=-1)
# # # ce_loss = F.cross_entropy(xcs / temperature, target, reduction="mean")
# # # diff_sum = 0
# # # for rr, cc in enumerate(range(128)):
# #     # diff_sum += xcs[rr, cc] - F.cosine_similarity(y[rr, :], y[cc, :], dim=-1)

In [None]:
# head = nn.Sequential(FP_Linear(64, 64, Nbits=None), nn.ReLU(True), FP_Linear(64, 64, Nbits=None))
# model = ResNetCIFAR(head_g=head).to(device)
# trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transforms.ToTensor())
# trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=1)
# b = None
# for idx, (inputs, label) in enumerate(trainloader):
#     print(inputs.shape)
#     b = augment_data(inputs)
#     break
# Preview the learning-rate schedule: linear warm-up for `warmup_iters` steps,
# then cosine annealing, both driving the same optimizer.
T_max = 1000
model = ResNetCIFAR()
optimizer = LARS(model.parameters(), lr=4.8, momentum=0.9, weight_decay=1e-6, nesterov=False)
warmup_iters = 10
# `verbose` is left at its default (off); newer PyTorch releases deprecate
# passing it explicitly.
scheduler_warmup = optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, end_factor=1.0,
                                               total_iters=warmup_iters)
scheduler_after = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max)

lr_list = []
for i in range(T_max):
    # The switch point follows warmup_iters instead of a second hard-coded 10.
    scheduler = scheduler_warmup if i < warmup_iters else scheduler_after
    # Record the rate the optimizer would really use at this step. Calling
    # scheduler.get_lr() outside of step() does not return the applied rate,
    # and get_last_lr() is stale on whichever scheduler did not step last.
    lr_list.append([group["lr"] for group in optimizer.param_groups])
    # No optimizer.step() is called here because nothing is being trained, so
    # PyTorch may warn about the step order.
    scheduler.step()
plt.plot(lr_list)
plt.xlabel("iteration")
plt.ylabel("learning rate")
plt.title("Linear warm-up followed by cosine annealing")
plt.show()

In [None]:
# idx = -2
# plt.imshow(b[idx, 0, :, :])
# plt.show()
# plt.imshow(b[idx+1, 0, :, :])
# plt.show()
# Toy check of ContrastiveLoss on a tiny hand-made batch.
crit = ContrastiveLoss(16)
# Earlier single-tensor variant, kept for reference:
# data_1 = torch.zeros((16, 3))
# for i in range(8):
#     data_1[2*i, :] = i
#     data_1[2*i+1, :] = i

n_rows, n_cols = 16, 3
# First view: every column of row i holds sqrt(i).
row_values = torch.from_numpy(np.sqrt(np.arange(n_rows))).to(torch.float32)
data_1 = row_values[:, None].repeat(1, n_cols)
# Second view: all zeros. Other fillings that were tried for row i:
#     data_2[i, :] = torch.exp(torch.arange(3))
#     data_2[i, :] = torch.arange(3)
#     data_2[i, :] = i
data_2 = torch.zeros((n_rows, n_cols))
crit(data_1, data_2)