In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import nbimporter
from NNModel import ECGClassifier
import pickle
import gzip

# Load Pretraining Data (PhysioNet): a gzip-compressed pickle holding a
# 2-item sequence that is unpacked into the features and the labels.
# NOTE: pickle.load can execute arbitrary code; only load this file from a trusted source.
with gzip.open("pretraining_data.pkl.gz", mode="rb") as f:
    pretraining_pair = pickle.load(f)
X_train, y_train = pretraining_pair

# Convert to PyTorch Dataset
class ECGDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing each ECG signal with its label.

    Args:
        X: Signals, as a tensor or anything ``torch.as_tensor`` accepts
            (e.g. a NumPy array). They are converted to PyTorch's default
            floating dtype. An input with fewer than 3 dims is treated as
            (samples, length) and gets a singleton channel axis inserted at
            dim 1. An input with 3 or more dims is assumed to already carry
            its channel axis and is kept as is.
        y: Labels, one per sample, indexable by position.

    Raises:
        ValueError: if ``X`` and ``y`` disagree on the number of samples.
    """

    def __init__(self, X, y):
        # Accept NumPy arrays as well as tensors. Casting to the default dtype
        # (what nn.Conv1d / nn.Linear weights use unless overridden) avoids a
        # float64-vs-float32 mismatch in the forward pass. For an input that
        # is already a tensor of that dtype, as_tensor returns it without copying.
        X = torch.as_tensor(X, dtype=torch.get_default_dtype())
        if X.dim() < 3:
            X = X.unsqueeze(1)  # Add channel dimension
        if len(X) != len(y):
            raise ValueError(
                f"X has {len(X)} samples but y has {len(y)} labels"
            )
        self.X = X
        self.y = y

    def __len__(self):
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(signal, label)`` pair at position ``idx``."""
        return self.X[idx], self.y[idx]

# Create DataLoader: mini-batches of `batch_size` samples, reshuffled every epoch.
batch_size = 64
train_dataset = ECGDataset(X_train, y_train)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Modify Model for Pretraining (Last FC Layer = 4 Classes)
class PretrainECGClassifier(ECGClassifier):
    """ECGClassifier configured for the 4-class pretraining task.

    Takes no arguments; it only fixes ``num_classes=4`` when calling the
    base-class constructor.
    """

    def __init__(self):
        # The pretraining label space has 4 classes (last FC layer = 4 outputs).
        super().__init__(num_classes=4)

# Initialize Model
model = PretrainECGClassifier()

# Pick the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
# nn.Module.to moves parameters/buffers in place and returns the same module.
model = model.to(device)

# He Normal Initialization
def he_init(m):
    """Re-draw a Linear/Conv1d layer's weight with He (Kaiming) normal init.

    Meant to be passed to ``nn.Module.apply``. Modules of any other type are
    left untouched, and biases are never modified. The gain is chosen for
    ReLU (``nonlinearity="relu"``); presumably the model uses ReLU
    activations, so confirm this against ECGClassifier.
    """
    if not isinstance(m, (nn.Linear, nn.Conv1d)):
        return
    nn.init.kaiming_normal_(m.weight, nonlinearity="relu")

# Re-initialize every Linear/Conv1d weight in the model with He-normal init (in place).
model.apply(he_init)

# Loss and Optimizer: cross-entropy for classification, Adam with a fixed learning rate.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Pretraining Loop: `num_epochs` passes over train_loader, printing the mean
# per-sample training loss after each epoch.
num_epochs = 3
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0  # sum of per-sample losses seen this epoch
    seen_samples = 0
    for X_batch, y_batch in train_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)

        optimizer.zero_grad()
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()

        # criterion (CrossEntropyLoss, default reduction="mean") returns the
        # batch average. Weight it by the batch size so the epoch figure is a
        # true per-sample mean even when the last batch is smaller than
        # batch_size. Averaging per-batch means would over-weight that batch.
        n_in_batch = X_batch.size(0)
        running_loss += loss.item() * n_in_batch
        seen_samples += n_in_batch

    # Avoid division by zero if the loader yielded no samples.
    epoch_loss = running_loss / seen_samples if seen_samples else float("nan")
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}")

# Save Pretrained Model with gzip and pickle.
# The state dict is pickled directly (not written with torch.save), so it must be
# read back with gzip.open + pickle.load. Only unpickle this file from a trusted
# source, because pickle.load can execute arbitrary code.
pretrained_state = model.state_dict()
# Move tensors to CPU before pickling. A pickled CUDA tensor can only be restored
# on a CUDA-capable machine. CPU tensors load anywhere, and load_state_dict()
# copies them onto the target model's device. Values are replaced in the original
# state dict (rather than building a new dict) so that its keys, order and
# _metadata are preserved. Non-tensor entries are left as they are.
for param_name in list(pretrained_state):
    param_value = pretrained_state[param_name]
    if isinstance(param_value, torch.Tensor):
        pretrained_state[param_name] = param_value.cpu()
with gzip.open("pretrained_model.pth.gz", "wb") as f:
    pickle.dump(pretrained_state, f)
print("Pretraining Complete. Model Saved as pretrained_model.pth.gz!")