In [4]:
from utils import *

# Task 1

## Part 2
The chosen model is ConvNet: ConvNetD3 for MNIST and ConvNetD7 for MHIST.

In [5]:
from ptflops import get_model_complexity_info
import matplotlib.pyplot as plt
from tqdm import tqdm


# Run on the GPU whenever PyTorch can see one; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class Args:
    """Lightweight attribute container carrying a compute-device string.

    After construction ``device`` is ``'cuda'`` when PyTorch reports an
    available GPU and ``'cpu'`` otherwise; callers may overwrite it.
    """

    def __init__(self):
        if torch.cuda.is_available():
            self.device = 'cuda'
        else:
            self.device = 'cpu'

def train(net, trainloader, criterion, optimizer, device=device, epochs=20):
    """Train ``net`` on ``trainloader`` for ``epochs`` passes and return it.

    Args:
        net: model to optimise; it is put in train mode and moved to ``device``.
        trainloader: iterable whose items index as ``(images, labels)``.
        criterion: loss callable, invoked as ``criterion(logits, labels)`` and
            expected to return a scalar tensor.
        optimizer: optimiser already bound to ``net``'s parameters.
        device: device that the model and every batch are moved to.
        epochs: number of full passes over ``trainloader`` (default 20, the
            value that used to be hard-coded).

    Returns:
        The same ``net`` object, trained in place.

    The progress-bar description shows the sample-weighted mean loss and
    accuracy of the most recently completed epoch.
    """
    net.train()
    net.to(device)
    progress_bar = tqdm(range(epochs), total=epochs, desc='Training - No data available', leave=True)

    for _epoch in progress_bar:
        loss_sum, correct, num_exp = 0.0, 0, 0
        for datum in trainloader:
            img = datum[0].float().to(device)
            lab = datum[1].long().to(device)
            n_b = lab.shape[0]

            output = net(img)
            loss = criterion(output, lab)

            # Weight by batch size so the epoch mean stays exact even when the
            # final batch is smaller than the others.
            loss_sum += loss.item() * n_b
            correct += (output.argmax(dim=-1) == lab).sum().item()
            num_exp += n_b

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if num_exp == 0:
            # Empty loader: nothing to average. Keep the "No data available"
            # description instead of dividing by zero.
            continue
        progress_bar.set_description(
            f'Training - Loss: {loss_sum / num_exp:.4f} - Accuracy: {correct / num_exp:.4f}'
        )
    return net

def test(net, testloader, criterion, optimizer, device=device):
    """Evaluate ``net`` on ``testloader``, print its accuracy and return metrics.

    Args:
        net: model to evaluate; it is put in eval mode and moved to ``device``.
        testloader: iterable whose items index as ``(images, labels)``.
        criterion: loss callable, invoked as ``criterion(logits, labels)``.
        optimizer: unused; kept only so existing call sites keep working.
        device: device that the model and every batch are moved to.

    Returns:
        ``(mean_loss, accuracy)`` over all test samples, with accuracy in [0, 1].

    Raises:
        ValueError: if ``testloader`` yields no samples.
    """
    net.eval()
    net.to(device)
    loss_sum, correct, num_exp = 0.0, 0, 0
    # Inference only: disabling autograd avoids building (and keeping alive)
    # a computation graph for every batch, which otherwise wastes memory.
    with torch.no_grad():
        for datum in testloader:
            img = datum[0].float().to(device)
            lab = datum[1].long().to(device)
            n_b = lab.shape[0]

            output = net(img)
            loss_sum += criterion(output, lab).item() * n_b
            correct += (output.argmax(dim=-1) == lab).sum().item()
            num_exp += n_b

    if num_exp == 0:
        raise ValueError('testloader yielded no samples')
    loss_avg = loss_sum / num_exp
    acc_avg = correct / num_exp
    print(f'Accuracy of the network on the test images: {acc_avg*100}%')
    return loss_avg, acc_avg

def count_flops(net, channel, im_size):
    """Print the complexity (MACs) and parameter count of ``net``.

    The model is profiled with ptflops' ``get_model_complexity_info`` using
    ``input_res = (channel, im_size[0], im_size[1])``. Both figures are
    printed as human-readable strings; nothing is returned.
    """
    height, width = im_size[0], im_size[1]
    macs_str, params_str = get_model_complexity_info(
        net,
        (channel, height, width),
        as_strings=True,
        print_per_layer_stat=False,
    )
    for label, value in (('FLOPs', macs_str), ('Params', params_str)):
        print(f'{label}: {value}')

In [3]:
# --- MNIST: train ConvNetD3 on the full training split ---
data_path = "./Project A/data"

(
    channel, im_size, num_classes, class_names,
    mean, std, dst_train, dst_test, testloader,
) = get_dataset("MNIST", data_path)

net = get_network('ConvNetD3', channel, num_classes, im_size)
net.to(device)

# Cross-entropy loss with plain SGD at a constant learning rate (no scheduler).
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=net.parameters(), lr=0.01)

# Shuffled mini-batches of 64; train() returns the trained network.
trainloader = torch.utils.data.DataLoader(dataset=dst_train, batch_size=64, shuffle=True)
net = train(net, trainloader, criterion, optimizer)

Training - Loss: 0.0116 - Accuracy: 0.9976: 100%|██████████| 20/20 [02:55<00:00,  8.77s/it]


In [4]:
# Evaluate the trained MNIST model on its test split, then report its cost.
test(net, testloader, criterion, optimizer, device=device)
count_flops(net, channel, im_size)

Accuracy of the network on the test images: 99.4%
FLOPs: 49.59 MMac
Params: 317.71 k


In [6]:
# --- MHIST: train ConvNetD7 on the full training split ---
data_path = "mhist_dataset"

(
    channel, im_size, num_classes, class_names,
    mean, std, dst_train, dst_test, testloader,
) = get_dataset("MHIST", data_path)

net = get_network('ConvNetD7', channel, num_classes, im_size)
net.to(device)

# Cross-entropy loss with plain SGD at a constant learning rate (no scheduler).
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=net.parameters(), lr=0.01)

# Shuffled mini-batches of 64; train() returns the trained network.
trainloader = torch.utils.data.DataLoader(dataset=dst_train, batch_size=64, shuffle=True)
net = train(net, trainloader, criterion, optimizer)
            
            

Training - No data available:   0%|          | 0/20 [00:00<?, ?it/s]

Training - Loss: 0.5546 - Accuracy: 0.7145:  10%|█         | 2/20 [00:09<01:22,  4.60s/it]

In [6]:
# Evaluate the trained MHIST model on its test split, then report its cost.
# NOTE(review): unlike the MNIST cell, evaluation is pinned to "cpu" here,
# which also moves `net` to the CPU (test() calls net.to(device)); confirm
# this is intentional.
test(net, testloader, criterion, optimizer, device="cpu")
count_flops(net, channel, im_size)

Accuracy of the network on the test images: 80.4503582395087%
FLOPs: 2.7 GMac
Params: 891.14 k


### Distillation function
Source of the DataDAM distillation code: <https://github.com/DataDistillation/DataDAM/blob/main/main_DataDAM.py>

In [3]:
from DataDAM import DataDAM

# --- MNIST: condense the training split with DataDAM (IPC=10) ---
data_path = "./Project A/data"

channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset("MNIST", data_path)

net = get_network('ConvNetD3', channel, num_classes, im_size)
net.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)

# NOTE(review): the last run of this cell failed with
# "TypeError: optimizer can only optimize Tensors, but one of the params is tuple".
# The traceback points inside DataDAM (not visible here), so the optimizer
# built there appears to receive a tuple; fix it in DataDAM.py.
distillator = DataDAM(net, dst_train, IPC=10, num_classes=num_classes, im_size=im_size, channels=channel, K=100, T=10, eta_S=0.1, zeta_S=1, eta_theta=0.01, zeta_theta=1, lambda_mmd=0.01, device=device)
distillator.initialize_synthetic_dataset()
distillator.train()

condensed_dataset = distillator.get_condensed_dataset()

# Show the first synthetic image. Detach it and move it to host memory so
# matplotlib can convert it (a CUDA or grad-tracking tensor would fail), and
# reshape with the dataset's own resolution instead of a hard-coded 28x28.
first_image = torch.as_tensor(condensed_dataset[0][0]).detach().cpu().numpy()
fig, ax = plt.subplots()
ax.imshow(first_image.reshape(im_size[0], im_size[1]), cmap='gray')
ax.set_title('First condensed MNIST image')
ax.axis('off')
plt.show()



  class_indices = torch.where(torch.tensor(self.real_dataset.targets) == i)[0]


TypeError: optimizer can only optimize Tensors, but one of the params is tuple