In [6]:
import os
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
import torch.nn.functional as F
import datetime
from models import model
from datasets import dosDataLoader

In [None]:
dos_file_path = '../../Datasets/dos64/'
dataset = dosDataLoader.StandardizingDosDataset(dos_file_path)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

model = model.dosClassifier(input_dim=3).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for epoch in range(5):
    model.train()
    total_loss = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader):.4f}")

checkpoint_path = "checkpoints/model" + datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + ".pth"
torch.save(model.state_dict(), checkpoint_path)
print(f"Model saved to {checkpoint_path}")


cuda
Epoch 1, Loss: 0.3573
Epoch 2, Loss: 0.3522
Epoch 3, Loss: 0.3516
Epoch 4, Loss: 0.3513
Epoch 5, Loss: 0.3510
Model saved to checkpoints/model20250523175651.pth


In [None]:
torch.load("checkpoints/model2025-05-23-18-02-35.pth")


OrderedDict([('fc1.weight',
              tensor([[-2.3452e+00,  2.8434e-01,  8.1569e-03],
                      [ 1.1547e-02, -5.3314e-02,  6.1722e-04],
                      [-6.7643e-02, -7.8720e-03, -1.2526e-02],
                      [ 2.2410e-01,  8.5369e-01,  1.5063e+00],
                      [ 1.9760e-01,  1.0724e+00,  9.0551e-02],
                      [ 8.0630e-02, -1.4307e-01, -1.5058e-01],
                      [ 2.3735e-01, -1.1827e+00, -4.0119e-02],
                      [ 1.8415e-01,  9.5504e-01, -2.6107e-02],
                      [-1.3456e-03,  1.6716e-04, -7.4186e-04],
                      [ 1.6722e-01, -5.1292e-01, -1.3069e+00],
                      [-1.2514e-01, -1.0157e-02,  6.8127e-02],
                      [ 1.6109e-01, -7.3267e-02,  1.0489e+00],
                      [ 1.8601e-01,  4.4830e-01,  1.2883e+00],
                      [-1.0353e-02,  4.5239e-04,  9.5033e-03],
                      [ 1.9973e-01, -1.0534e+00,  1.3873e-01],
                      [ 1.4