In [1]:
import torch
import torchvision
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from utils.tools import dotdict
from driver.driver import ABC_Driver
from torch_geometric_temporal import METRLADatasetLoader
from other_model.other_model import make_default_model
import atd2022
# GPU selection. The training configs below request ['cuda:2', 'cuda:3'], and
# set_device makes cuda:2 the *current* CUDA device, so bare 'cuda' allocations
# resolve to it. The previous bare `torch.cuda.is_available()` was evaluated and
# discarded (not the cell's last expression, so nothing was shown); it is now a
# real guard that fails fast with a readable message.
PRIMARY_CUDA_DEVICE = 2
if not torch.cuda.is_available():
    raise RuntimeError('CUDA is not available, but this notebook trains on cuda:2/cuda:3.')
if torch.cuda.device_count() <= PRIMARY_CUDA_DEVICE:
    raise RuntimeError(
        f'Need at least {PRIMARY_CUDA_DEVICE + 1} visible CUDA devices, '
        f'found {torch.cuda.device_count()}.'
    )
torch.cuda.set_device(PRIMARY_CUDA_DEVICE)

In [2]:
# --- CIFAR-10 configuration ---------------------------------------------------
# Architecture knobs. These are module-level names; the CIFAR-100 cell re-binds them.
activation = 'relu'
input_channel = 3
# Per-stage channel widths; knpp[0] is also used as the atrc2d group count
# (7th conv-tuple slot, the one the CIFAR-100 cell names `groups`).
knpp = [48, 96, 144, 240, 336, 432, 528]

cifar10_args = dotdict()

# dataset / batching / hardware
cifar10_args.name = 'cifar10'
cifar10_args.train_batch_size = 128
cifar10_args.predict_batch_size = 128
cifar10_args.device = ['cuda:2', 'cuda:3']

# optimisation
cifar10_args.train_epochs = 250
cifar10_args.lr = 0.01
cifar10_args.criterion = 'CE'
cifar10_args.optimizer = 'AdamW'
cifar10_args.scheduler = 'OneCycle'
# adversarial-evaluation parameters; their meaning is defined by ABC_Driver
cifar10_args.attack = {'fgsm': (0.031,), 'pgd': (0.031, 1, 20)}

# Layer spec: a plain cnn2d stem, then one 'atrc2d' stage per consecutive pair
# of widths (knpp[i-1] -> knpp[i], i = 1..6). Every second stage (2, 4, 6)
# carries ('first', (2, 2)) in its pooling slots, presumably 2x2 pooling
# (semantics defined by the driver); the rest carry (None, None).
# Previously tried, inserted just before the pooling layer (kept for reference):
#   ('cnn2d', ((knpp[-1], knpp[-1], (3,3), 1, 0, 1, knpp[0]), 1, None, None, activation, False)),
#   ('cnn2d', ((knpp[-1], knpp[-1], (2,2), 1, 0, 1, knpp[0]), 1, None, None, False, False)),
cifar10_args.layers = (
    [('cnn2d', ((input_channel, knpp[0], (3, 3), 1, 1, 1, 1), 1, None, None, activation, False))]
    + [
        ('atrc2d', ((c_in, c_out, (3, 3), 1, 1, 1, knpp[0]), 1,
                    'first' if stage % 2 == 0 else None,
                    (2, 2) if stage % 2 == 0 else None,
                    activation, True))
        for stage, (c_in, c_out) in enumerate(zip(knpp[:-1], knpp[1:]), start=1)
    ]
    + [('adptavgpool', (1, 1)), ('linear', (knpp[-1], 10, (1, 2, 3)))]
)

In [3]:
# --- CIFAR-100 configuration --------------------------------------------------
# Re-binds the module-level architecture knobs set by the CIFAR-10 cell.
activation = 'relu'
input_channel = 3
# Per-stage channel widths (all multiples of `groups`).
knpp = [40, 80, 120, 160, 200, 240, 280, 320, 360, 400]
# Earlier width schedule, kept for reference: knpp = [48,96,144,240,336,432,528]
groups = 40  # group count passed in the 7th conv-tuple slot of every atrc2d stage

cifar100_args = dotdict()

# dataset / batching / hardware
cifar100_args.name = 'cifar100'
cifar100_args.train_batch_size = 100
cifar100_args.predict_batch_size = 100
cifar100_args.device = ['cuda:2', 'cuda:3']

# optimisation
cifar100_args.train_epochs = 250
cifar100_args.lr = 0.2
cifar100_args.criterion = 'CE'
cifar100_args.optimizer = 'SGD'
cifar100_args.scheduler = 'multistep'
# adversarial-evaluation parameters; their meaning is defined by ABC_Driver
cifar100_args.attack = {'fgsm': (0.005,), 'pgd': (0.005, 0.1, 20)}

# Layer spec: a plain cnn2d stem, then one 'atrc2d' stage per consecutive pair
# of widths (knpp[i-1] -> knpp[i], i = 1..9). Every third stage (3, 6, 9)
# carries ('first', (2, 2)) in its pooling slots, presumably 2x2 pooling
# (semantics defined by the driver); the rest carry (None, None).
# Previously tried, inserted just before the pooling layer (kept for reference):
#   ('cnn2d', ((knpp[-1], knpp[-1], (3,3), 1, 0, 1, knpp[0]), 1, None, None, activation, False)),
#   ('cnn2d', ((knpp[-1], knpp[-1], (2,2), 1, 0, 1, knpp[0]), 1, None, None, False, False)),
cifar100_args.layers = (
    [('cnn2d', ((input_channel, knpp[0], (3, 3), 1, 1, 1, 1), 1, None, None, activation, False))]
    + [
        ('atrc2d', ((c_in, c_out, (3, 3), 1, 1, 1, groups), 1,
                    'first' if stage % 3 == 0 else None,
                    (2, 2) if stage % 3 == 0 else None,
                    activation, True))
        for stage, (c_in, c_out) in enumerate(zip(knpp[:-1], knpp[1:]), start=1)
    ]
    + [('adptavgpool', (1, 1)), ('linear', (knpp[-1], 100, (1, 2, 3)))]
)

In [None]:
# Build the driver from the CIFAR-100 config and run the full training loop.
# NOTE(review): the captured log shows ~667 s/epoch, so 250 epochs is roughly
# 46 h of training, and the checkpoint-save cell further down is commented out.
# Consider persisting the trained weights before the kernel is restarted.
driver = ABC_Driver(
    cifar100_args,
    None,  # second positional argument; its meaning is defined by ABC_Driver
    record_path=None,
    if_hash=False,
)
driver.train()

Use: ['cuda:2', 'cuda:3']
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
add record: 03/25/2023 20:00
epoch: 0, train_loss: 4.3045, test_metric: 0.1112, time: 670.1726574897766
epoch: 1, train_loss: 3.9027, test_metric: 0.1811, time: 667.0814135074615
epoch: 2, train_loss: 3.4672, test_metric: 0.2806, time: 666.9277930259705
epoch: 3, train_loss: 3.0334, test_metric: 0.3591, time: 667.0463027954102
epoch: 4, train_loss: 2.7571, test_metric: 0.4102, time: 666.8434376716614
epoch: 5, train_loss: 2.5253, test_metric: 0.4478, time: 667.1520984172821
epoch: 6, train_loss: 2.3595, test_metric: 0.482, time: 666.8790338039398
epoch: 7, train_loss: 2.2335, test_metric: 0.4911, time: 666.9742844104767
epoch: 8, train_loss: 2.1404, test_metric: 0.5055, time: 667.0379712581635
epoch: 9, train_loss: 2.0652, test_metric: 0.5295, time: 667.0900840759277
epoch: 10, train_loss: 1.9928, test_metric: 0.5378, time: 666.827228307724
epoch: 

In [3]:
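# Checkpoint save/load, currently disabled.
# NOTE(review): the save path (2023_03_16) and load path (2023_03_13) differ from each
# other, and both predate the run logged above ("add record: 03/25/2023").
# Update the filenames before re-enabling either line.
# torch.save(driver.model.state_dict(), "save/CIFAR100_ABC_2023_03_16.pt")
# driver.model.load_state_dict(torch.load("save/CIFAR100_ABC_2023_03_13.pt"))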