In [1]:
import torch
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.nn as nn
from torch.utils.data import DataLoader, Subset

# Config

In [2]:
batch_size = 64
num_classes = 10  # CIFAR-10 has 10 classes
learning_rate = 0.001
num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Load CIFAR-10 Dataset
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # ResNet-50 expects 224x224 images
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
train_dataset = datasets.CIFAR10(root="data", train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root="data", train=False, download=True, transform=transform)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:04<00:00, 34491078.48it/s]


Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified


# Model

In [4]:
%%capture
# Load Pre-trained ResNet-50 Model
model = models.resnet50(pretrained=True)
model = model.to(device)

# Remove the classifier to extract features
feature_extractor = torch.nn.Sequential(*list(model.children())[:-1])  # Remove the last FC layer
feature_extractor.eval()

# Split Dataset

In [5]:
def prepare_data(client_id, dataset):
    # Define label subsets for each client
    client_labels = {
        1: [0, 1, 2, 3, 4],         # Airplane, Automobile, Bird, Cat, Deer
        2: [1, 2, 3, 4, 5, 6, 7, 8],  # Automobile, Bird, Cat, Deer, Dog, Frog, Horse, Ship, Truck
        3: list(range(10)),          # All classes
    }
    
    labels_for_client = client_labels[client_id]

    # Filter dataset for the client's labels
    indices = [i for i, (_, label) in enumerate(dataset) if label in labels_for_client]
    client_dataset = Subset(dataset, indices)
    return client_dataset

In [6]:
client1_data = prepare_data (1 , train_dataset)
client2_data = prepare_data (2 , train_dataset)
client3_data = prepare_data (3 , train_dataset)

In [7]:
client1_dataloader = DataLoader(client1_data, batch_size=batch_size, shuffle=True)
client2_dataloader = DataLoader(client2_data, batch_size=batch_size, shuffle=True)
client3_dataloader = DataLoader(client3_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

# Extract features

In [8]:
def extract_features (dataloader) :
    features = []
    label_list = []
    with torch.no_grad():
        for images, labels in tqdm(dataloader):  # CIFAR-10 labels are not needed for feature extraction
            images = images.to(device)
            outputs = feature_extractor(images)  # Extract features
            outputs = outputs.view(outputs.size(0), -1)  # Flatten the features
            features.append(outputs.cpu())  # Move to CPU and store
            label_list.append (labels.cpu ())
            #print (label_list)
# Concatenate all features and save
    features = torch.cat(features, dim=0)
    label_list = torch.cat (label_list, dim = 0)
    return features, label_list

In [9]:
test_image_feature, test_image_label = extract_features (test_dataloader)
client1_image_feature, client1_image_label = extract_features (client1_dataloader)
client2_image_feature, client2_image_label = extract_features (client2_dataloader)
client3_image_feature, client3_image_label = extract_features (client3_dataloader)

100%|██████████| 157/157 [00:29<00:00,  5.41it/s]
100%|██████████| 391/391 [01:11<00:00,  5.48it/s]
100%|██████████| 625/625 [01:54<00:00,  5.47it/s]
100%|██████████| 782/782 [02:22<00:00,  5.48it/s]


In [10]:
torch.save(test_image_feature, "/kaggle/working/test_image_feature.pt")
torch.save(test_image_label, "/kaggle/working/test_image_label.pt")

torch.save(client1_image_feature, "/kaggle/working/client1_image_feature.pt")
torch.save(client1_image_label, "/kaggle/working/client1_image_label.pt")

torch.save(client2_image_feature, "/kaggle/working/client2_image_feature.pt")
torch.save(client2_image_label, "/kaggle/working/client2_image_label.pt")

torch.save(client3_image_feature, "/kaggle/working/client3_image_feature.pt")
torch.save(client3_image_label, "/kaggle/working/client3_image_label.pt")

In [11]:
test_image_feature = torch.load( "/kaggle/working/test_image_feature.pt" , weights_only = True)
test_image_label = torch.load( "/kaggle/working/test_image_label.pt", weights_only = True)
train_image_feature = torch.load( "/kaggle/working/client3_image_feature.pt", weights_only = True)
train_image_label = torch.load( "/kaggle/working/client3_image_label.pt", weights_only = True)

# Softmax Model

In [26]:
# Define the Softmax Regression Model
class SoftmaxRegression(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(SoftmaxRegression, self).__init__()
        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        return self.linear(x)

In [27]:
# Initialize the model
input_dim = train_image_feature.size(1)  # Number of input features
model = SoftmaxRegression(input_dim, 10).to(device)

In [28]:
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [29]:
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import accuracy_score
import torchvision

# Split dataset into training and validation sets
#labels = torch.tensor(train_dataset.targets)  # CIFAR-10 labels are integers 0-9
X_train, X_val, y_train, y_val = train_test_split(
    train_image_feature, train_image_label, test_size=0.2, random_state=42
)
#X_train, y_train, X_val, y_val = train_image_feature , train_image_label , test_image_feature, test_image_label
# Create PyTorch Datasets and DataLoaders
train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

test_dataset =  TensorDataset(test_image_feature, test_image_label)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [30]:
# Training Loop
for epoch in range(num_epochs):
    model.train()
    for X_batch, y_batch in train_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)

        # Forward pass
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch )

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Validation
    model.eval()
    val_predictions, val_labels = [], []
    with torch.no_grad():
        for X_batch, y_batch in val_loader:
            X_batch = X_batch.to(device)
            outputs = model(X_batch)
            _, predicted = torch.max(outputs, 1)
            val_predictions.extend(predicted.cpu().numpy())
            val_labels.extend(y_batch.numpy())

    # Calculate validation accuracy
    val_accuracy = accuracy_score(val_labels, val_predictions)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}, Validation Accuracy: {val_accuracy:.4f}")

print("Training Complete.")

Epoch [1/10], Loss: 0.4077, Validation Accuracy: 0.8763
Epoch [2/10], Loss: 0.4238, Validation Accuracy: 0.8878
Epoch [3/10], Loss: 0.4363, Validation Accuracy: 0.8992
Epoch [4/10], Loss: 0.3741, Validation Accuracy: 0.9007
Epoch [5/10], Loss: 0.2617, Validation Accuracy: 0.9016
Epoch [6/10], Loss: 0.3068, Validation Accuracy: 0.9062
Epoch [7/10], Loss: 0.2046, Validation Accuracy: 0.9030
Epoch [8/10], Loss: 0.1271, Validation Accuracy: 0.9040
Epoch [9/10], Loss: 0.4109, Validation Accuracy: 0.8988
Epoch [10/10], Loss: 0.2647, Validation Accuracy: 0.8937
Training Complete.


In [31]:
model.eval()
val_predictions, val_labels = [], []
with torch.no_grad():
    for X_batch, y_batch in test_loader:
        X_batch = X_batch.to(device)
        outputs = model(X_batch)
        _, predicted = torch.max(outputs, 1)
        val_predictions.extend(predicted.cpu().numpy())
        val_labels.extend(y_batch.numpy())

# Calculate validation accuracy
val_accuracy = accuracy_score(val_labels, val_predictions)
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}, Validation Accuracy: {val_accuracy:.4f}")


Epoch [10/10], Loss: 0.2647, Validation Accuracy: 0.8871
