In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from scripts.dataset import ProteinDataset
from scripts.model import SiameseNetwork

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Training hyperparameters (previously inlined as literals in the script body).
NUM_EPOCHS = 10
BATCH_SIZE = 32
LEARNING_RATE = 1e-3
CHECKPOINT_PATH = "model.pt"


def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the epoch's mean loss.

    The returned value is a per-sample mean: each batch's loss (assumed to be
    mean-reduced, as ``nn.CosineEmbeddingLoss`` is by default) is weighted by
    the batch size. Dividing the sum of batch means by ``len(loader)`` instead
    over-weights a smaller final batch.

    Args:
        model: Module called as ``model(x1, x2)`` that returns a pair of
            embeddings ``(z1, z2)``.
        loader: Iterable of ``(x1, x2, y)`` batches. ``y`` is assumed to hold
            0/1 pair labels (1 = similar pair) -- TODO confirm against the
            dataset that produces it.
        criterion: Loss called as ``criterion(z1, z2, target)`` with
            ``target`` in {-1, +1}, e.g. ``nn.CosineEmbeddingLoss``.
        optimizer: Optimizer over ``model``'s parameters; stepped once per batch.
        device: Device every batch is moved to before the forward pass.

    Returns:
        The sample-weighted mean loss over all samples yielded by ``loader``.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    # Explicit: a previous ``model.eval()`` (e.g. in another cell) would persist.
    model.train()
    loss_sum = 0.0
    n_samples = 0
    for x1, x2, y in loader:
        x1, x2 = x1.to(device), x2.to(device)
        # CosineEmbeddingLoss wants +1 (pull together) / -1 (push apart)
        # targets; map labels {0, 1} -> {-1, +1}.
        y = y.to(device) * 2 - 1

        z1, z2 = model(x1, x2)
        loss = criterion(z1, z2, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = y.size(0)
        loss_sum += loss.item() * batch_size
        n_samples += batch_size

    if n_samples == 0:
        raise ValueError("loader yielded no samples; cannot compute epoch loss")
    return loss_sum / n_samples


# Guarded so that importing this code does not launch a training run. Inside a
# Jupyter/IPython kernel ``__name__`` is "__main__", so the cell still trains and
# still leaves ``dataset``, ``loader``, ``model``, ``optimizer`` and
# ``criterion`` defined as globals for later cells.
if __name__ == "__main__":
    dataset = ProteinDataset("data/processed/structures.npz")
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    model = SiameseNetwork().to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CosineEmbeddingLoss()

    for epoch in range(NUM_EPOCHS):
        epoch_loss = train_one_epoch(model, loader, criterion, optimizer, DEVICE)
        print(f"Epoch {epoch+1} | Loss: {epoch_loss:.4f}")

    torch.save(model.state_dict(), CHECKPOINT_PATH)


Epoch 1 | Loss: 0.4483
Epoch 2 | Loss: 0.3328
Epoch 3 | Loss: 0.3310
Epoch 4 | Loss: 0.3469
Epoch 5 | Loss: 0.3345
Epoch 6 | Loss: 0.3604
Epoch 7 | Loss: 0.3469
Epoch 8 | Loss: 0.3242
Epoch 9 | Loss: 0.3295
Epoch 10 | Loss: 0.3584
