# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import pickle
import gzip
from sklearn.model_selection import train_test_split
from NNModel import ECGClassifier

# Load the MIT-BIH beats: the gzip-compressed pickle holds a two-item
# (features, labels) pair.
with gzip.open("mitbih_beats.pkl.gz", "rb") as f:
    X_mit, y_mit = pickle.load(f)

# Hold out 20% of the beats for testing. Stratifying on the labels keeps the
# class proportions the same in both splits; the fixed seed makes the split
# reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X_mit,
    y_mit,
    test_size=0.2,
    random_state=42,
    stratify=y_mit,
)

# Dataset
class ECGDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing each beat with its label.

    Args:
        X: Beats as a 2-D array-like of shape (N, L). Accepts a
            ``torch.Tensor`` or anything ``torch.as_tensor`` can convert
            (NumPy array, nested list). The dtype is left unchanged.
        y: Labels, one per beat; same accepted types as ``X``.

    Raises:
        ValueError: If ``X`` and ``y`` do not contain the same number of
            samples.
    """

    def __init__(self, X, y):
        # as_tensor returns an existing tensor unchanged and converts
        # NumPy arrays / lists, which have no .unsqueeze().
        X = torch.as_tensor(X)
        y = torch.as_tensor(y)
        if len(X) != len(y):
            raise ValueError(
                f"X and y must have the same length, got {len(X)} and {len(y)}"
            )
        # Insert a singleton dimension at position 1: (N, L) -> (N, 1, L).
        self.X = X.unsqueeze(1)
        self.y = y

    def __len__(self):
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(beat, label)`` pair at ``idx``; the beat is (1, L)."""
        return self.X[idx], self.y[idx]

# DataLoader
batch_size = 250

# Training batches are drawn in shuffled order.
train_loader = torch.utils.data.DataLoader(
    dataset=ECGDataset(X_train, y_train),
    batch_size=batch_size,
    shuffle=True,
)

# Test batches keep the dataset's original order.
test_loader = torch.utils.data.DataLoader(
    dataset=ECGDataset(X_test, y_test),
    batch_size=batch_size,
    shuffle=False,
)

# Model Init
model = ECGClassifier(num_classes=5)  # MIT-BIH has 5 beat classes
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Load pretrained weights. The checkpoint file is required: this step is not
# guarded, so a missing file raises here.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load checkpoints from a trusted source.
with gzip.open("pretrained_model.pth.gz", "rb") as f:
    state_dict = pickle.load(f)

# Drop every pretrained entry whose key starts with 'fc2' before loading.
# NOTE(review): the head replaced below is `model.fc`, not `fc2` -- confirm
# the prefix matches the checkpoint's layer names.
state_dict = {k: v for k, v in state_dict.items() if not k.startswith('fc2')}

# strict=False tolerates key mismatches, so report them instead of silently
# ending up with a partially (or not at all) initialized model.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(f"Pretrained weights missing for: {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"Unexpected pretrained keys ignored: {load_result.unexpected_keys}")

# Replace the final FC layer with a freshly initialized 5-way head on the
# same device; any weights just loaded into `fc` are discarded.
model.fc = nn.Linear(model.fc.in_features, 5).to(device)  # Update FC layer

# Optimizer: Adam over all model parameters, with the betas spelled out.
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
)

# Loss Function: cross-entropy between model outputs and class labels.
criterion = nn.CrossEntropyLoss()

# Training Loop: one optimizer step per mini-batch; after each epoch, print
# the mean of the per-batch losses.
num_epochs = 100
for epoch in range(num_epochs):
    model.train()  # put the model in training mode at the start of each epoch
    running_loss = 0.0  # sum of per-batch loss values for this epoch

    for X_batch, y_batch in train_loader:
        # Move the mini-batch to the device the model was placed on.
        X_batch = X_batch.to(device)
        y_batch = y_batch.to(device)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward pass and loss.
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"[Epoch {epoch+1}/{num_epochs}] Loss: {running_loss/len(train_loader):.4f}")
