In [2]:
import torch, torch.nn as nn, random
from torch.utils.data import DataLoader, Subset
from torchaudio.datasets import SPEECHCOMMANDS
import torchaudio
from myapp.task import Net, collate_fn

# ==========================================
# CONFIG
# ==========================================
DATA_DIR = "./data"
SAMPLE_LIMIT = 2000
BATCH_SIZE = 16
LR = 3e-4
EPOCHS = 3
SAVE_PATH = "warmup_pretrained.pt"

# ==========================================
# 1) LOAD PARTIAL (2000 SAMPLES) DATASET
# ==========================================
print("ðŸ“¥ Loading 2000 samples for warm-up...")

dataset = SPEECHCOMMANDS(DATA_DIR, download=False)
indices = list(range(len(dataset)))
random.shuffle(indices)
indices = indices[:SAMPLE_LIMIT]

subset = Subset(dataset, indices)

trainloader = DataLoader(
    subset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)

print(f"âœ… Warmup dataset size: {len(trainloader.dataset)} samples")

# ==========================================
# 2) INITIALIZE MODEL (Wav2Vec2 frozen + classifier)
# ==========================================
model = Net()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.classifier.parameters(), lr=LR)

# ==========================================
# 3) TRAIN THE CLASSIFIER ONLY (WARM-UP TRAINING)
# ==========================================
print("\nðŸš€ Starting warm-up training on server...\n")

for epoch in range(EPOCHS):
    total_loss = 0
    for waveforms, labels in trainloader:
        waveforms = waveforms.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(waveforms)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(trainloader)
    print(f"[Warmup] Epoch {epoch+1}/{EPOCHS} â€” Loss: {avg_loss:.4f}")

# ==========================================
# 4) SAVE PRETRAINED MODEL
# ==========================================
torch.save(model.state_dict(), SAVE_PATH)
print(f"\nðŸ”¥ Warm-up complete! Model saved â†’ {SAVE_PATH}")


ðŸ“¥ Loading 2000 samples for warm-up...
âœ… Warmup dataset size: 2000 samples

ðŸš€ Starting warm-up training on server...

[Warmup] Epoch 1/3 â€” Loss: 2.7685
[Warmup] Epoch 2/3 â€” Loss: 2.4212
[Warmup] Epoch 3/3 â€” Loss: 2.1356

ðŸ”¥ Warm-up complete! Model saved â†’ warmup_pretrained.pt


In [3]:
# ==========================================
# 3.5) QUICK ACCURACY CHECK (optional)
# ==========================================
print("\nðŸ“Š Running quick accuracy test...")

# Take 200 samples from SAME dataset for evaluation
eval_indices = indices[200:400]
eval_subset = Subset(dataset, eval_indices)

evalloader = DataLoader(
    eval_subset,
    batch_size=32,
    shuffle=False,
    collate_fn=collate_fn,
)

correct = 0
total = 0

model.eval()
with torch.no_grad():
    for waveforms, labels in evalloader:
        waveforms = waveforms.to(device)
        labels = labels.to(device)

        outputs = model(waveforms)
        preds = outputs.argmax(dim=1)

        correct += (preds == labels).sum().item()
        total += labels.size(0)

warmup_acc = correct / total
print(f"ðŸ”¥ Warm-up Accuracy: {warmup_acc*100:.2f}%")
model.train()



ðŸ“Š Running quick accuracy test...
ðŸ”¥ Warm-up Accuracy: 38.51%


Net(
  (classifier): Sequential(
    (0): Linear(in_features=768, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=20, bias=True)
  )
)

In [5]:
from myapp.task import Net
import flwr as fl
import torch

net = Net()
net.load_state_dict(torch.load("warmup_pretrained.pt", map_location="cpu"))
net.eval()

# Convert into Flower parameters
initial_parameters = fl.common.ndarrays_to_parameters(
    [val.cpu().numpy() for _, val in net.state_dict().items()]
)

# Save to disk for server
import pickle
with open("initial_params.pkl", "wb") as f:
    pickle.dump(initial_parameters, f)

print("Saved initial_parameters.pkl")


Saved initial_parameters.pkl
