In [1]:
import os

# The notebook presumably lives one directory below the repository root; the
# cells below import `FLF` and read "data/cifar10" relative to the cwd.
# Only step up when `FLF` is not already visible, so re-running this cell does
# not keep climbing one more directory per execution.
if not os.path.isdir("FLF"):
    os.chdir(os.pardir)

In [2]:
import copy
import torch as th
import numpy as np
import random
from torch.optim import Adam, SGD
import torch.nn as nn
from torchvision import datasets, transforms
from FLF.model.TorchResNetFactory import TorchResNetFactory

In [115]:
SEED = 42

# Seed Python's, NumPy's and PyTorch's global RNGs (in that order) so that
# model initialisation is reproducible across kernel restarts.
for _seed_rng in (random.seed, np.random.seed, th.manual_seed):
    _seed_rng(SEED)
del _seed_rng

# Force deterministic cuDNN kernels and turn off its autotuner, which may pick
# different algorithms from run to run.
th.backends.cudnn.deterministic = True
th.backends.cudnn.benchmark = False

In [116]:
# ToTensor yields values in [0, 1]; normalising each channel with mean 0.5 and
# std 0.5 maps them to [-1, 1].
norm = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

trfs = [
    transforms.ToTensor(),
    norm,
]

transform = transforms.Compose(trfs)

# This is CIFAR-10, not CIFAR-100. The old, misleading name is kept as an alias
# only so that any cell still referring to `cifar100_train_ds` keeps working.
cifar10_train_ds = datasets.CIFAR10(
    "data/cifar10", download=True, transform=transform,
)
cifar100_train_ds = cifar10_train_ds

# NOTE(review): the first loss printed by the training cells is ~5.0, above even
# ln(100) ~= 4.6 and far from ln(10) ~= 2.3; worth confirming that
# TorchResNetFactory builds a 10-way output head for this dataset — TODO confirm.
train_loader = th.utils.data.DataLoader(
    dataset=cifar10_train_ds,
    batch_size=1024,
    # Keep the batch order fixed: the baseline and proxy-gradient training cells
    # compare their per-batch losses, which only makes sense on identical batches.
    shuffle=False,
    pin_memory=True,
)

Files already downloaded and verified


In [117]:
lr = 1e-3  # step size shared by the SGD optimisers built in this notebook

model = TorchResNetFactory("group", None)()
model = model.cuda()
loss_fn = nn.CrossEntropyLoss()
opt = SGD(params=model.parameters(), lr=lr)

In [118]:
def set_model_grad(opt, model, new_state):
    """Store in ``model``'s ``.grad`` fields the gradient that moves it to ``new_state``.

    For every parameter ``p`` (named ``name``) of ``model`` this sets
    ``p.grad = (p - new_state[name]) / lr_p``, where ``lr_p`` is the learning rate
    that ``opt`` applies to ``p``. The sign is "old minus new" because SGD steps
    against the gradient, so a following ``opt.step()`` with plain SGD (no
    momentum, no weight decay) moves ``model`` onto ``new_state`` up to
    floating-point rounding.

    Args:
        opt: optimiser whose ``param_groups[*]["lr"]`` provides the step size.
            When all groups share one learning rate it is used for every
            parameter, so ``model`` need not be the module ``opt`` was built on.
            With differing group learning rates, each parameter of ``model``
            must belong to one of ``opt``'s groups.
        model: module whose parameter ``.grad`` fields are overwritten. Its
            weights and buffers are left unchanged.
        new_state: mapping such as ``other_model.state_dict()`` that holds a
            tensor for every name in ``model.named_parameters()``. Extra keys
            (e.g. buffers) are ignored.

    Raises:
        KeyError: if ``new_state`` lacks one of ``model``'s parameter names, or a
            parameter is missing from ``opt`` when group learning rates differ.
        ValueError: if the learning rate for a parameter is zero.
    """
    group_lrs = [float(group["lr"]) for group in opt.param_groups]
    if len(set(group_lrs)) == 1:
        shared_lr = group_lrs[0]
        lr_by_param = None
    else:
        shared_lr = None
        # Parameters are matched by identity because tensors are not hashable by value.
        lr_by_param = {
            id(param): float(group["lr"])
            for group in opt.param_groups
            for param in group["params"]
        }

    with th.no_grad():
        for name, parameter in model.named_parameters():
            step_lr = shared_lr if lr_by_param is None else lr_by_param[id(parameter)]
            if step_lr == 0:
                raise ValueError(
                    f"learning rate for parameter {name!r} is 0; cannot derive a gradient"
                )
            # Align device/dtype so the assignment to `.grad` is accepted.
            target = new_state[name].to(device=parameter.device, dtype=parameter.dtype)
            parameter.grad = (parameter.detach() - target) / step_lr

In [119]:
# Baseline run: ordinary SGD on `model` over batches 0..25 of a single epoch,
# printing each batch loss so it can be compared with the proxy-gradient run.
for curr_epoch in range(1):
    for curr_batch, (data, target) in enumerate(train_loader):
        data = data.cuda()
        target = target.cuda()

        opt.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        opt.step()

        print(loss.item())
        if curr_batch >= 25:  # stop after the 26th batch
            break

5.007709980010986
4.740602493286133
4.506201267242432
4.292863368988037
4.074743747711182
3.9392948150634766
3.752699375152588
3.6096889972686768
3.499223470687866
3.4076027870178223
3.3114264011383057
3.2011375427246094
3.123434066772461
3.0595245361328125
3.013124465942383
2.944967269897461
2.8808624744415283
2.8453006744384766
2.8077552318573
2.7561919689178467
2.6987318992614746
2.6778147220611572
2.644211530685425
2.6113688945770264
2.6075446605682373
2.5749411582946777


In [109]:
# Proxy-gradient run: each batch takes one plain-SGD step on a throwaway copy of
# `model`, converts the copy's weight change back into a gradient with
# `set_model_grad`, and applies that gradient with the real optimiser. With plain
# SGD this should reproduce the baseline loss trace up to floating-point rounding.
#
# Rebuild `model` and `opt` from the same seed first. Otherwise, under
# Restart & Run All, this cell would continue from the `model` that the baseline
# cell has already trained, and the comparison would be meaningless.
random.seed(SEED)
np.random.seed(SEED)
th.manual_seed(SEED)
model = TorchResNetFactory("group", None)().cuda()
opt = SGD(model.parameters(), lr=lr)

for curr_epoch in range(1):
    for curr_batch, (data, target) in enumerate(train_loader):
        data, target = data.cuda(), target.cuda()

        # deepcopy keeps the copy on the same device as `model`.
        proxy_model = copy.deepcopy(model)
        proxy_opt = SGD(proxy_model.parameters(), lr=lr)
        proxy_opt.zero_grad()
        output = proxy_model(data)
        loss = loss_fn(output, target)
        loss.backward()
        proxy_opt.step()

        # Replay the proxy's step on the real model via a reconstructed gradient.
        opt.zero_grad()
        set_model_grad(opt, model, proxy_model.state_dict())
        opt.step()

        print(loss.item())
        if curr_batch == 25:
            break

5.007709980010986
4.740602493286133
4.506201267242432
4.292863368988037
4.074743747711182
3.939298152923584
3.7527012825012207
3.609694480895996
3.4992260932922363
3.4076087474823
3.3114237785339355
3.2011518478393555
3.1234350204467773
3.0595273971557617
3.0131373405456543
2.944974422454834
2.8808672428131104
2.8453073501586914
2.807766914367676
2.7561891078948975
2.6987390518188477
2.6778228282928467
2.6442205905914307
2.611361026763916
2.607542037963867
2.5749404430389404


In [64]:
# model.state_dict()

In [65]:
# model = set_model_grad(opt, model, TorchResNetFactory("group", None)().state_dict())

In [66]:
# model.state_dict()

In [67]:
# Show the leading 5x5 slice (first two dims) of the first parameter's gradient.
next(model.parameters()).grad[:5, :5]

tensor([[[[-2.9913e-03, -2.2625e-03, -2.0705e-03,  9.1495e-05, -2.1973e-04,
           -2.4900e-03, -6.5309e-03],
          [-3.8049e-03, -2.0920e-03, -2.8431e-03, -2.2389e-03, -2.9731e-03,
           -6.0943e-03, -1.0834e-02],
          [-7.6445e-03, -4.1388e-03, -9.3102e-04, -3.2318e-03, -8.6277e-03,
           -8.5048e-03, -9.7686e-03],
          [-4.2320e-03, -2.2804e-03,  5.2502e-04, -1.7948e-03, -5.1756e-03,
           -4.5959e-03, -7.2187e-03],
          [-1.2224e-03, -4.5456e-04, -7.8907e-05, -3.1534e-03, -6.5503e-03,
           -7.0068e-03, -6.7319e-03],
          [ 1.7490e-03,  2.3119e-03, -1.2318e-03, -2.2847e-03, -6.1482e-03,
           -5.5855e-03, -2.8146e-03],
          [ 2.9871e-03,  2.1173e-03, -1.0877e-03,  1.4219e-03, -1.0624e-03,
           -6.2058e-03, -4.2820e-03]],

         [[ 2.4421e-03,  1.0626e-03,  9.1221e-05,  3.6967e-03,  4.9774e-03,
            3.1004e-03, -2.0609e-03],
          [ 2.0699e-03,  1.9007e-03, -1.7900e-04,  1.4305e-03,  2.0817e-03,
          

In [68]:
# Rebuild a gradient by hand. Take the current weights of `model` as the target
# state, swap `model` back to the snapshot `prev_model`, and store
# (prev - current) / lr in each parameter's `.grad`. This is exactly what
# `set_model_grad` computes, so reuse it rather than duplicating the loop with a
# hard-coded 1e-3.
#
# NOTE(review): no visible cell defines `prev_model`, so this cell raises
# NameError on a fresh kernel. It presumably needs a snapshot such as
# `prev_model = copy.deepcopy(model)` taken before the step being replayed — TODO confirm.
# NOTE(review): after `model = prev_model`, `opt` still holds the parameters of
# the module it was built on, so a later `opt.step()` would not update `model`.
new_state = model.state_dict()
model = prev_model
set_model_grad(opt, model, new_state)

In [69]:
next(iter(model.parameters())).grad[:5, :5]

tensor([[[[-2.9914e-03, -2.2631e-03, -2.0713e-03,  9.1270e-05, -2.1979e-04,
           -2.4885e-03, -6.5304e-03],
          [-3.8054e-03, -2.0936e-03, -2.8431e-03, -2.2389e-03, -2.9728e-03,
           -6.0946e-03, -1.0833e-02],
          [-7.6443e-03, -4.1388e-03, -9.3086e-04, -3.2317e-03, -8.6278e-03,
           -8.5048e-03, -9.7677e-03],
          [-4.2319e-03, -2.2799e-03,  5.2527e-04, -1.7956e-03, -5.1744e-03,
           -4.5951e-03, -7.2187e-03],
          [-1.2224e-03, -4.5449e-04, -7.9162e-05, -3.1535e-03, -6.5509e-03,
           -7.0073e-03, -6.7316e-03],
          [ 1.7490e-03,  2.3134e-03, -1.2317e-03, -2.2845e-03, -6.1467e-03,
           -5.5842e-03, -2.8146e-03],
          [ 2.9872e-03,  2.1160e-03, -1.0878e-03,  1.4231e-03, -1.0626e-03,
           -6.2063e-03, -4.2822e-03]],

         [[ 2.4419e-03,  1.0617e-03,  9.1270e-05,  3.6964e-03,  4.9775e-03,
            3.1004e-03, -2.0609e-03],
          [ 2.0699e-03,  1.8999e-03, -1.7905e-04,  1.4305e-03,  2.0824e-03,
          