In [37]:
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision.models.resnet import resnet50, Bottleneck
import torch.fx as fx

import torch
import torch._dynamo
import numpy as np
import pickle
from math import ceil
from typing import List

0: Load the data

In [29]:
def unpickle(file):
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='latin1')
    return dict

data_batch_1 = unpickle('/home/lyy/data/cifar-10-batches-py/data_batch_1')
data_batch_2 = unpickle('/home/lyy/data/cifar-10-batches-py/data_batch_2')
data_batch_3 = unpickle('/home/lyy/data/cifar-10-batches-py/data_batch_3')
data_batch_4 = unpickle('/home/lyy/data/cifar-10-batches-py/data_batch_4')
data_batch_5 = unpickle('/home/lyy/data/cifar-10-batches-py/data_batch_5')
test_batch = unpickle('/home/lyy/data/cifar-10-batches-py/test_batch')


X_train_data = np.concatenate((data_batch_1['data'], data_batch_2['data'], data_batch_3['data'], data_batch_4['data'], data_batch_5['data']))
X_train_labels = data_batch_1['labels'] + data_batch_2['labels'] + data_batch_3['labels'] + data_batch_4['labels'] + data_batch_5['labels']
X_test_data = test_batch['data']
X_test_labels = test_batch['labels']

1: preprocess the data

In [30]:
X_train_data = X_train_data.reshape(len(X_train_data),3,32,32)

batch_size = 128
split_data_list = np.array_split(X_train_data,ceil(X_train_data.shape[0]/batch_size), axis=0)
how_many_batches = len(split_data_list)
split_labels_list = np.array_split(X_train_labels,ceil(len(X_train_labels)/batch_size), axis=0)

2 import model and define optimizer

In [32]:
net = resnet50(10).to('cuda')

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor = 0.1, patience=5)



In [33]:
def segment(inputs, labels, losses):
    inputs = split_data_list[i]
    labels = split_labels_list[i]

    inputs = torch.from_numpy(inputs)
    inputs = inputs.float()
    inputs = inputs.to('cuda')

    labels = torch.from_numpy(labels)
    labels = labels.to('cuda')

    outputs = net(inputs)
    loss = criterion(outputs, labels)
    losses.append(loss.item())

    loss.backward()
    optimizer.step()
    
    optimizer.zero_grad()

    outputs = net(inputs)
    loss = criterion(outputs, labels)
    losses.append(loss.item())

    loss.backward()
    optimizer.step()

    return loss.item()

4: start to train

In [34]:
def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("custom backend called wiht FX graph:")

    gm.graph.print_tabular()
    return gm.forward

torch._dynamo.reset()

opt_model_segment = torch.compile(segment, backend=custom_backend)

In [35]:
EPOCHS = 1 # actual:200
for epoch in range(EPOCHS):
    losses = []
    running_loss = 0

    for i in range(how_many_batches):
        # move to the segment funciton
        loss = segment(split_data_list[i], split_labels_list[i], losses)
        # prove of concept: can be put into the loop and print everything
        # print(opt_model_segment(split_data_list[i], split_labels_list[i], losses))

        running_loss += loss
        
        if i%100 == 0 and i > 0:
            print(f'Loss [{epoch+1}, {i}](epoch, minibatch): ', running_loss / 100)
            running_loss = 0.0


    avg_loss = sum(losses)/len(losses)
    scheduler.step(avg_loss)

print('Training Done')

custom backend called wiht FX graph:
opcode       name       target     args               kwargs
-----------  ---------  ---------  -----------------  --------
placeholder  l_stack0_  L_stack0_  ()                 {}
call_method  float_1    float      (l_stack0_,)       {}
call_method  to         to         (float_1, 'cuda')  {}
output       output     output     ((to,),)           {}
custom backend called wiht FX graph:
opcode         name                                        target                                                      args                                                                             kwargs
-------------  ------------------------------------------  ----------------------------------------------------------  -------------------------------------------------------------------------------  --------
placeholder    l_stack0_                                   L_stack0_                                                   ()                                     

KeyboardInterrupt: 

5: custom backend

In [16]:
def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("custom backend called wiht FX graph:")

    gm.graph.print_tabular()
    return gm.forward

In [20]:
torch._dynamo.reset()

6: print out the graph of the torch.nn.module(aka resnet50)

In [27]:
#gm = torch.fx.symbolic_trace(m)
gm = torch.fx.symbolic_trace(net)
# call and print the graph
gm.graph.print_tabular()

opcode         name                   target                                                      args                                   kwargs
-------------  ---------------------  ----------------------------------------------------------  -------------------------------------  --------
placeholder    x                      x                                                           ()                                     {}
call_module    conv1                  conv1                                                       (x,)                                   {}
call_module    bn1                    bn1                                                         (conv1,)                               {}
call_module    relu                   relu                                                        (bn1,)                                 {}
call_module    maxpool                maxpool                                                     (relu,)                                {}
call_modul

  torch.has_cuda,
  torch.has_cudnn,
  torch.has_mps,
  torch.has_mkldnn,


In [24]:
# TODO: wouldn't work, do not know why yet
opt_model = torch.compile(net, backend=custom_backend)
print(opt_model(split_data_list[0]))

TypeError: conv2d() received an invalid combination of arguments - got (numpy.ndarray, Parameter, NoneType, tuple, tuple, tuple, int), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (!numpy.ndarray!, !Parameter!, !NoneType!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, int)
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, str padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (!numpy.ndarray!, !Parameter!, !NoneType!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, int)


7: print out the graph of segment(only 1 slice of data)

In [40]:
opt_model_segment = torch.compile(segment, backend=custom_backend)
print(opt_model_segment(split_data_list[0], split_labels_list[0], losses=[]))



custom backend called wiht FX graph:
opcode       name       target     args               kwargs
-----------  ---------  ---------  -----------------  --------
placeholder  l_stack0_  L_stack0_  ()                 {}
call_method  float_1    float      (l_stack0_,)       {}
call_method  to         to         (float_1, 'cuda')  {}
output       output     output     ((to,),)           {}
custom backend called wiht FX graph:
opcode         name                                        target                                                      args                                                                             kwargs
-------------  ------------------------------------------  ----------------------------------------------------------  -------------------------------------------------------------------------------  --------
placeholder    l_stack0_                                   L_stack0_                                                   ()                                     

In [None]:
try:
    torch._dynamo.export(bar, torch.randn(10), torch.randn(10))
except:
    tb.print_exc()

model_exp = torch._dynamo.export(init_model(), generate_data(16)[0])
print(opt_model_segment(split_data_list[0], split_labels_list[0], losses=[]))