In [133]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from models.Models import Models
from models.ClientModelStrategy import ClientModelStrategy
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [106]:
models = ClientModelStrategy.strategy_1(4)

In [107]:
models

[<function models.Models.Models.ResNet18()>,
 <function models.Models.Models.ResNet18()>,
 <function models.Models.Models.ResNet34()>,
 <function models.Models.Models.ResNet34()>]

In [108]:
model = Models.available["resnet18"]()

In [112]:
model = models[1]()

In [134]:
model = Models.available["resnet18"]()
mean, std = [0.47889522, 0.47227842, 0.43047404], [0.24205776, 0.23828046, 0.25874835]
train_transform = transforms.Compose([
    # transforms.Resize((32,32)),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])
dataset = ImageFolder(root="dataset/cinic-10/train", transform=train_transform)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4, pin_memory=True)

torch.manual_seed(0)
model = model(10)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)



In [135]:
def train(model, optimizer, criterion, epochs):
    """
    Train the client model

    Args:
        num_epoch (int): number of epochs to train for
    """
    # self.model.to(self.device) 
    model.train()

    for epoch in range(epochs):
        total_loss = 0
        total_correct = 0
        for batch_idx, (data, target) in enumerate(dataloader):
            # Send data and target to device
            data, target = data.to(device), target.to(device)
            
            # Zero out gradients
            optimizer.zero_grad()

            # Forward pass
            output = model(data)
            loss = criterion(output, target)

            total_loss+= loss.item()
            total_correct+= output.argmax(dim=1).eq(target).sum().item()

            loss.backward()
            optimizer.step()

            if batch_idx % 100 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(dataloader.dataset),
                    100. * batch_idx / len(dataloader), loss.item()))
        
        writer.add_scalar("Loss", total_loss/len(dataloader), epoch)
        writer.add_scalar("Accuracy", total_correct/len(dataloader), epoch)


In [136]:
train(model, optimizer, criterion, 20)



In [100]:
writer.close()

In [6]:
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

test_dataset = ImageFolder(root="dataset/cinic-10/test", transform=test_transform)
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=True, num_workers=4, pin_memory=True)

In [15]:
synthetic_dataset = ImageFolder(root="diffusion/dataset/cifar_10/", transform=test_transform)
synthetic_dataloader = DataLoader(synthetic_dataset, batch_size=128, shuffle=True, num_workers=4, pin_memory=True)

In [102]:
@torch.inference_mode()
def evaluate(model):
    model.eval()
    with torch.no_grad():
        total_loss = 0
        count = 0
        total_accuracy = 0
        for batch_idx, (data, target) in enumerate(test_dataloader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            total_loss += loss.item() * len(data)
            count += len(data)
            accuracy = (output.argmax(dim=1) == target).float().mean().item()
            total_accuracy += accuracy * len(data)
        print("len(data): {}, len(dataloader): {}".format(count, len(test_dataset)))
        print("Test loss: {}".format(total_loss / count))
        print("Test accuracy: {}".format(total_accuracy / count))



In [89]:
from knowledge_distillation import SoftTarget
kd_temperature = 1
kd_alpha = 0.5
kd_epochs = 10

def knowledge_distillation(server_logits, synthetic_data=None, diffusion_seed=None):
    """
    Knowledge distillation from server to client

    Args:
        server_logits (torch.Tensor): logits from the server model
        synthetic_data (torch.utils.data.DataLoader): synthetic diffusion data - if not generated at runtime
        diffusion_seed (int): random seed for diffusion sampling - if generated at runtime
    Returns:
        torch.Tensor: logits from the client model
    """

    # Generate synthetic data if not provided
    # if synthetic_data is None:
    #     synthetic_data = generate_diffusion(diffusion_seed)

    # model.to(device)
    model.train()

    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    kd_criterion = SoftTarget(kd_temperature).to(device)
    criterion = nn.CrossEntropyLoss().to(device)

    for epoch in range(kd_epochs):
        for batch_idx, ((data, target), logit) in enumerate(zip(synthetic_data, server_logits)):

            logit = logit[0]

            data, target, logit = data.to(device), target.to(device), logit.to(device)

            optimizer.zero_grad()

            output = model(data)

            # loss = kd_alpha * kd_criterion(output, server_logits, target)
            loss = (1 - kd_alpha) * criterion(output, target) + \
                    kd_alpha * kd_criterion(output, logit)

            loss.backward()
            optimizer.step()


def generate_logit(model, synthetic_data, diffusion_seed=None):
    """
    Generate logits from the client model

    Args:
        diffusion_seed (int): random seed for diffusion sampling - if generated at runtime
    Returns:
        torch.Tensor: logits from the client model
    """

    # if synthetic_data is None:
    #     synthetic_data = generate_diffusion(diffusion_seed)

    # model.to(device)
    model.eval()

    with torch.no_grad():
        logits = []

        for batch_idx, (data, target) in enumerate(synthetic_data):
            data, target = data.to(device), target.to(device)

            output = model(data)
            logits.append(output)

    # return torch.cat(logits, dim=0) 
    return torch.cat(logits).detach().cpu()


In [16]:
generate_logit(synthetic_dataloader)

tensor([[-3.7136e-01,  5.5566e+00, -1.9034e+00,  ..., -2.2747e+00,
         -2.3779e-01,  5.6667e+00],
        [-2.7515e-01,  4.7884e+00, -1.8918e+00,  ..., -2.1806e+00,
          9.2094e-01,  5.4238e+00],
        [-1.9213e+00, -3.7828e+00,  1.2176e+00,  ...,  3.2784e+00,
         -3.6679e+00, -3.2022e+00],
        ...,
        [ 2.7700e+00, -1.7160e-01, -1.2474e+00,  ..., -2.5995e+00,
          5.7527e+00,  7.1538e-01],
        [ 3.9338e+00, -1.6628e+00,  1.2267e-03,  ..., -4.3972e+00,
          7.8782e+00, -7.3171e-01],
        [-2.6094e+00, -4.4847e+00,  2.9344e+00,  ..., -1.0425e+00,
         -3.0934e+00, -4.3059e+00]], device='cuda:0')

In [19]:
model_2 = Models.available["resnet34"]()
torch.manual_seed(0)
model_2 = model_2(10)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_2 = model_2.to(device)
criterion_2 = nn.CrossEntropyLoss()
optimizer_2 = optim.SGD(model_2.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

In [21]:
train(model_2, optimizer_2, criterion_2, 10)



In [103]:
evaluate(model_2)

len(data): 90000, len(dataloader): 90000
Test loss: 1.265936116557651
Test accuracy: 0.5461444444444444


In [70]:
logits = generate_logit(model_2, synthetic_dataloader)

In [34]:
from torch.utils.data import TensorDataset

In [72]:
logit_dataloader = DataLoader(TensorDataset(logits), batch_size=128)

In [90]:
knowledge_distillation(logit_dataloader, synthetic_dataloader)

In [91]:
evaluate(model)

Test loss: 1.6232765786488852
Test accuracy: 0.5146666666666667


In [47]:
for elem in enumerate(logit_dataloader):
    print(len(elem))
    break

2


In [84]:
for elem in logit_dataloader:
    vari, = elem
    print(vari.shape)
    break


torch.Size([128, 10])


In [56]:
logits.shape

torch.Size([2500, 10])

In [68]:
TensorDataset(logits)

<torch.utils.data.dataset.TensorDataset at 0x20f48263210>

In [71]:
logits.shape

torch.Size([2500, 10])

In [101]:
foo = 2
print(f"Client_{foo:02}")

Client_02


In [120]:
seed = torch.randint(0, 100000, (1,)).item()
print(seed)
torch.manual_seed(seed)

80591


<torch._C.Generator at 0x20f0f53b990>

In [118]:
torch.manual_seed(seed).seed()

153691923685100

In [125]:
torch.seed()

166425389026300

In [122]:
import random
new_seed = random.randint(0, 100000)
print(new_seed)
torch.manual_seed(new_seed)

6640


<torch._C.Generator at 0x20f0f53b990>

In [123]:
torch.seed()

166404888985300

In [124]:
torch.manual_seed(new_seed)

<torch._C.Generator at 0x20f0f53b990>

In [126]:
import os
os.makedirs("checkpoints/test", exist_ok=True)

In [132]:
import time
timestr = time.strftime("%Y%m%d-%H%M%S")
os.makedirs(f"checkpoints/{timestr}", exist_ok=True)

In [128]:
timestr

'20230711-170452'

In [129]:
print(f"/checkpoints/{timestr}")

/checkpoints/20230711-170452
