In [2]:
!pip install torch

Collecting torch
  Downloading torch-2.5.1-cp310-cp310-manylinux1_x86_64.whl.metadata (28 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [6]:
# !pip install h5py

Collecting h5py
  Downloading h5py-3.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.5 kB)
Downloading h5py-3.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.3/5.3 MB[0m [31m68.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: h5py
Successfully installed h5py-3.12.1


In [2]:
#dataloading and preprocessing
import h5py
with h5py.File('tactmat.h5', 'r') as dataset:
    samples = dataset['samples'][:]
    materials = dataset['materials'][:]
    materials = [m.decode('utf-8') for m in materials]
train_samples = samples[:, :80,...]  # (36, 80, 1000, 4, 4)
test_samples = samples[:, 80:,...]  # (36, 20, 1000, 4, 4)

# Flatten data
train_samples = train_samples.reshape((36*80, 1000, 16))  # (2880, 1000, 16)
test_samples = test_samples.reshape((36*20, 1000, 16))  # (720, 1000, 16)

# Create labels
train_labels = np.repeat(np.arange(36), 80)  # (2880,)
test_labels = np.repeat(np.arange(36), 20)  # (720,)

# Shuffle
indices = torch.randperm(len(train_samples))
train_samples = train_samples[indices]
train_labels = train_labels[indices]

# One-hot encoding not needed for pytorch CrossEntropyLoss
train_samples = torch.tensor(train_samples.reshape(2880, 1, 1000, 16), dtype=torch.float32)
test_samples = torch.tensor(test_samples.reshape(720, 1, 1000, 16), dtype=torch.float32)
train_labels = torch.tensor(train_labels, dtype=torch.long)
test_labels = torch.tensor(test_labels, dtype=torch.long)

In [6]:
# Model Definition
class MCTestTacNet(nn.Module):
    def __init__(self):
        super(MCTestTacNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=(15, 5))
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(kernel_size=(10, 1), stride=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=(15, 5))
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(kernel_size=(10, 1), stride=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=(15, 5))
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(kernel_size=(10, 1), stride=1)
        self.dropout = nn.Dropout(0.8)

        # Dummy forward pass to calculate the flattened size
        self.flattened_size = self._get_flatten_size()

        self.fc1 = nn.Linear(self.flattened_size, 512)
        self.bn4 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 36)

    def _get_flatten_size(self): # helper function for dynamic sizing of linear layer
        with torch.no_grad():
            x = torch.zeros(1, 1, 1000, 16)  # example
            x = self.pool3(self.bn3(self.conv3(self.pool2(self.bn2(self.conv2(self.pool1(self.bn1(self.conv1(x)))))))))
            return x.view(1, -1).size(1)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool1(x)
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool2(x)
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.pool3(x)
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.bn4(self.fc1(self.dropout(x))))
        x = self.fc2(x)  # no softmax bc cross entropy loss includes it
        return x

# Instantiate Model
model = MCTestTacNet()
model.to(device)


# Training func
def train_model(model, train_samples, train_labels,val_samples,val_labels,  num_epochs=10, batch_size=32, learning_rate=0.001):
    train_dataset = TensorDataset(train_samples, train_labels)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    
    val_dataset = TensorDataset(val_samples, val_labels)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)    
    
    train_samples = train_samples.to(device)
    train_labels = train_labels.to(device)
    
    # Define loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=170, gamma=0.1)
    
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0

        for inputs, labels in tqdm(train_loader, desc=f"Epoch [{epoch+1}/{num_epochs}]", unit="batch"):
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs) #forward pass
            loss = criterion(outputs, labels) #loss
            loss.backward() #backward pass
            optimizer.step() #update params
            running_loss += loss.item() #calc running loss

        average_loss = running_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {average_loss:.4f}")
        
        scheduler.step()
        
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

    val_loss /= len(val_loader)
    print(f"Validation Loss: {val_loss:.4f}")

    return model, val_loss
    
    print("Training complete!")

#run training
trained_model, val_error = train_model(model, train_samples, train_labels, test_samples, test_labels, num_epochs=200, batch_size=32, learning_rate=0.001)

# Save the model and validation error
torch.save(trained_model.state_dict(), "trained_model.pth")
with open("val_error.txt", "w") as f:
    f.write(str(val_error))

Epoch [1/200]: 100%|██████████| 90/90 [00:23<00:00,  3.77batch/s]


Epoch [1/200], Loss: 2.6790


Epoch [2/200]: 100%|██████████| 90/90 [00:24<00:00,  3.66batch/s]


Epoch [2/200], Loss: 1.6281


Epoch [3/200]: 100%|██████████| 90/90 [00:24<00:00,  3.64batch/s]


Epoch [3/200], Loss: 1.1287


Epoch [4/200]: 100%|██████████| 90/90 [00:24<00:00,  3.70batch/s]


Epoch [4/200], Loss: 0.8313


Epoch [5/200]: 100%|██████████| 90/90 [00:24<00:00,  3.70batch/s]


Epoch [5/200], Loss: 0.6316


Epoch [6/200]: 100%|██████████| 90/90 [00:24<00:00,  3.66batch/s]


Epoch [6/200], Loss: 0.4273


Epoch [7/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [7/200], Loss: 0.2722


Epoch [8/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [8/200], Loss: 0.2249


Epoch [9/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [9/200], Loss: 0.1578


Epoch [10/200]: 100%|██████████| 90/90 [00:24<00:00,  3.69batch/s]


Epoch [10/200], Loss: 0.0861


Epoch [11/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [11/200], Loss: 0.0556


Epoch [12/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [12/200], Loss: 0.0485


Epoch [13/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [13/200], Loss: 0.0334


Epoch [14/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [14/200], Loss: 0.0354


Epoch [15/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [15/200], Loss: 0.0405


Epoch [16/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [16/200], Loss: 0.0311


Epoch [17/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [17/200], Loss: 0.0320


Epoch [18/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [18/200], Loss: 0.0226


Epoch [19/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [19/200], Loss: 0.0373


Epoch [20/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [20/200], Loss: 0.0697


Epoch [21/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [21/200], Loss: 0.0484


Epoch [22/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [22/200], Loss: 0.0334


Epoch [23/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [23/200], Loss: 0.0449


Epoch [24/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [24/200], Loss: 0.0348


Epoch [25/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [25/200], Loss: 0.0332


Epoch [26/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [26/200], Loss: 0.0345


Epoch [27/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [27/200], Loss: 0.0329


Epoch [28/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [28/200], Loss: 0.0206


Epoch [29/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [29/200], Loss: 0.0206


Epoch [30/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [30/200], Loss: 0.0289


Epoch [31/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [31/200], Loss: 0.0186


Epoch [32/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [32/200], Loss: 0.0257


Epoch [33/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [33/200], Loss: 0.0608


Epoch [34/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [34/200], Loss: 0.0432


Epoch [35/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [35/200], Loss: 0.0249


Epoch [36/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [36/200], Loss: 0.0105


Epoch [37/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [37/200], Loss: 0.0110


Epoch [38/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [38/200], Loss: 0.0073


Epoch [39/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [39/200], Loss: 0.0170


Epoch [40/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [40/200], Loss: 0.0520


Epoch [41/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [41/200], Loss: 0.0460


Epoch [42/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [42/200], Loss: 0.0421


Epoch [43/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [43/200], Loss: 0.0553


Epoch [44/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [44/200], Loss: 0.0282


Epoch [45/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [45/200], Loss: 0.0208


Epoch [46/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [46/200], Loss: 0.0111


Epoch [47/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [47/200], Loss: 0.0058


Epoch [48/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [48/200], Loss: 0.0144


Epoch [49/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [49/200], Loss: 0.0220


Epoch [50/200]: 100%|██████████| 90/90 [00:24<00:00,  3.69batch/s]


Epoch [50/200], Loss: 0.0470


Epoch [51/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [51/200], Loss: 0.0571


Epoch [52/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [52/200], Loss: 0.0371


Epoch [53/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [53/200], Loss: 0.0161


Epoch [54/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [54/200], Loss: 0.0080


Epoch [55/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [55/200], Loss: 0.0033


Epoch [56/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [56/200], Loss: 0.0042


Epoch [57/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [57/200], Loss: 0.0024


Epoch [58/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [58/200], Loss: 0.0055


Epoch [59/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [59/200], Loss: 0.0022


Epoch [60/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [60/200], Loss: 0.0040


Epoch [61/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [61/200], Loss: 0.0081


Epoch [62/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [62/200], Loss: 0.0141


Epoch [63/200]: 100%|██████████| 90/90 [00:24<00:00,  3.69batch/s]


Epoch [63/200], Loss: 0.0270


Epoch [64/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [64/200], Loss: 0.0497


Epoch [65/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [65/200], Loss: 0.0391


Epoch [66/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [66/200], Loss: 0.0312


Epoch [67/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [67/200], Loss: 0.0288


Epoch [68/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [68/200], Loss: 0.0111


Epoch [69/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [69/200], Loss: 0.0043


Epoch [70/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [70/200], Loss: 0.0049


Epoch [71/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [71/200], Loss: 0.0098


Epoch [72/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [72/200], Loss: 0.0121


Epoch [73/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [73/200], Loss: 0.0025


Epoch [74/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [74/200], Loss: 0.0023


Epoch [75/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [75/200], Loss: 0.0010


Epoch [76/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [76/200], Loss: 0.0009


Epoch [77/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [77/200], Loss: 0.0007


Epoch [78/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [78/200], Loss: 0.0005


Epoch [79/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [79/200], Loss: 0.0006


Epoch [80/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [80/200], Loss: 0.0002


Epoch [81/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [81/200], Loss: 0.0004


Epoch [82/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [82/200], Loss: 0.0004


Epoch [83/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [83/200], Loss: 0.0005


Epoch [84/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [84/200], Loss: 0.0006


Epoch [85/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [85/200], Loss: 0.0002


Epoch [86/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [86/200], Loss: 0.0003


Epoch [87/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [87/200], Loss: 0.0002


Epoch [88/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [88/200], Loss: 0.0002


Epoch [89/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [89/200], Loss: 0.0002


Epoch [90/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [90/200], Loss: 0.0002


Epoch [91/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [91/200], Loss: 0.0002


Epoch [92/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [92/200], Loss: 0.0001


Epoch [93/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [93/200], Loss: 0.0001


Epoch [94/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [94/200], Loss: 0.0001


Epoch [95/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [95/200], Loss: 0.0002


Epoch [96/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [96/200], Loss: 0.0002


Epoch [97/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [97/200], Loss: 0.0002


Epoch [98/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [98/200], Loss: 0.0270


Epoch [99/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [99/200], Loss: 0.3994


Epoch [100/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [100/200], Loss: 0.1049


Epoch [101/200]: 100%|██████████| 90/90 [00:24<00:00,  3.69batch/s]


Epoch [101/200], Loss: 0.0362


Epoch [102/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [102/200], Loss: 0.0296


Epoch [103/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [103/200], Loss: 0.0100


Epoch [104/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [104/200], Loss: 0.0098


Epoch [105/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [105/200], Loss: 0.0052


Epoch [106/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [106/200], Loss: 0.0021


Epoch [107/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [107/200], Loss: 0.0018


Epoch [108/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [108/200], Loss: 0.0024


Epoch [109/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [109/200], Loss: 0.0033


Epoch [110/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [110/200], Loss: 0.0026


Epoch [111/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [111/200], Loss: 0.0016


Epoch [112/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [112/200], Loss: 0.0013


Epoch [113/200]: 100%|██████████| 90/90 [00:24<00:00,  3.69batch/s]


Epoch [113/200], Loss: 0.0040


Epoch [114/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [114/200], Loss: 0.0030


Epoch [115/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [115/200], Loss: 0.0043


Epoch [116/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [116/200], Loss: 0.0121


Epoch [117/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [117/200], Loss: 0.0137


Epoch [118/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [118/200], Loss: 0.0049


Epoch [119/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [119/200], Loss: 0.0045


Epoch [120/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [120/200], Loss: 0.0093


Epoch [121/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [121/200], Loss: 0.0028


Epoch [122/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [122/200], Loss: 0.0116


Epoch [123/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [123/200], Loss: 0.0250


Epoch [124/200]: 100%|██████████| 90/90 [00:24<00:00,  3.69batch/s]


Epoch [124/200], Loss: 0.0234


Epoch [125/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [125/200], Loss: 0.0282


Epoch [126/200]: 100%|██████████| 90/90 [00:24<00:00,  3.69batch/s]


Epoch [126/200], Loss: 0.0075


Epoch [127/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [127/200], Loss: 0.0023


Epoch [128/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [128/200], Loss: 0.0031


Epoch [129/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [129/200], Loss: 0.0015


Epoch [130/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [130/200], Loss: 0.0081


Epoch [131/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [131/200], Loss: 0.0094


Epoch [132/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [132/200], Loss: 0.0044


Epoch [133/200]: 100%|██████████| 90/90 [00:24<00:00,  3.69batch/s]


Epoch [133/200], Loss: 0.0071


Epoch [134/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [134/200], Loss: 0.0063


Epoch [135/200]: 100%|██████████| 90/90 [00:24<00:00,  3.69batch/s]


Epoch [135/200], Loss: 0.0118


Epoch [136/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [136/200], Loss: 0.0023


Epoch [137/200]: 100%|██████████| 90/90 [00:24<00:00,  3.69batch/s]


Epoch [137/200], Loss: 0.0011


Epoch [138/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [138/200], Loss: 0.0007


Epoch [139/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [139/200], Loss: 0.0005


Epoch [140/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [140/200], Loss: 0.0006


Epoch [141/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [141/200], Loss: 0.0004


Epoch [142/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [142/200], Loss: 0.0005


Epoch [143/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [143/200], Loss: 0.0003


Epoch [144/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [144/200], Loss: 0.0003


Epoch [145/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [145/200], Loss: 0.0003


Epoch [146/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [146/200], Loss: 0.0003


Epoch [147/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [147/200], Loss: 0.0003


Epoch [148/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [148/200], Loss: 0.0211


Epoch [149/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [149/200], Loss: 0.1799


Epoch [150/200]: 100%|██████████| 90/90 [00:24<00:00,  3.69batch/s]


Epoch [150/200], Loss: 0.0437


Epoch [151/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [151/200], Loss: 0.0209


Epoch [152/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [152/200], Loss: 0.0100


Epoch [153/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [153/200], Loss: 0.0025


Epoch [154/200]: 100%|██████████| 90/90 [00:24<00:00,  3.67batch/s]


Epoch [154/200], Loss: 0.0017


Epoch [155/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [155/200], Loss: 0.0020


Epoch [156/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [156/200], Loss: 0.0012


Epoch [157/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [157/200], Loss: 0.0013


Epoch [158/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [158/200], Loss: 0.0014


Epoch [159/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [159/200], Loss: 0.0064


Epoch [160/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [160/200], Loss: 0.0018


Epoch [161/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [161/200], Loss: 0.0017


Epoch [162/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [162/200], Loss: 0.0091


Epoch [163/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [163/200], Loss: 0.0087


Epoch [164/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [164/200], Loss: 0.0018


Epoch [165/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [165/200], Loss: 0.0006


Epoch [166/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [166/200], Loss: 0.0009


Epoch [167/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [167/200], Loss: 0.0024


Epoch [168/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [168/200], Loss: 0.0081


Epoch [169/200]: 100%|██████████| 90/90 [00:24<00:00,  3.69batch/s]


Epoch [169/200], Loss: 0.0068


Epoch [170/200]: 100%|██████████| 90/90 [00:24<00:00,  3.69batch/s]


Epoch [170/200], Loss: 0.0200


Epoch [171/200]: 100%|██████████| 90/90 [00:24<00:00,  3.69batch/s]


Epoch [171/200], Loss: 0.0063


Epoch [172/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [172/200], Loss: 0.0027


Epoch [173/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [173/200], Loss: 0.0027


Epoch [174/200]: 100%|██████████| 90/90 [00:24<00:00,  3.69batch/s]


Epoch [174/200], Loss: 0.0010


Epoch [175/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [175/200], Loss: 0.0034


Epoch [176/200]: 100%|██████████| 90/90 [00:24<00:00,  3.69batch/s]


Epoch [176/200], Loss: 0.0008


Epoch [177/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [177/200], Loss: 0.0010


Epoch [178/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [178/200], Loss: 0.0011


Epoch [179/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [179/200], Loss: 0.0011


Epoch [180/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [180/200], Loss: 0.0005


Epoch [181/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [181/200], Loss: 0.0006


Epoch [182/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [182/200], Loss: 0.0004


Epoch [183/200]: 100%|██████████| 90/90 [00:24<00:00,  3.69batch/s]


Epoch [183/200], Loss: 0.0016


Epoch [184/200]: 100%|██████████| 90/90 [00:24<00:00,  3.69batch/s]


Epoch [184/200], Loss: 0.0007


Epoch [185/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [185/200], Loss: 0.0007


Epoch [186/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [186/200], Loss: 0.0009


Epoch [187/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [187/200], Loss: 0.0006


Epoch [188/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [188/200], Loss: 0.0007


Epoch [189/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [189/200], Loss: 0.0003


Epoch [190/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [190/200], Loss: 0.0006


Epoch [191/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [191/200], Loss: 0.0002


Epoch [192/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [192/200], Loss: 0.0005


Epoch [193/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [193/200], Loss: 0.0009


Epoch [194/200]: 100%|██████████| 90/90 [00:24<00:00,  3.69batch/s]


Epoch [194/200], Loss: 0.0009


Epoch [195/200]: 100%|██████████| 90/90 [00:24<00:00,  3.69batch/s]


Epoch [195/200], Loss: 0.0004


Epoch [196/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [196/200], Loss: 0.0002


Epoch [197/200]: 100%|██████████| 90/90 [00:24<00:00,  3.69batch/s]


Epoch [197/200], Loss: 0.0002


Epoch [198/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [198/200], Loss: 0.0006


Epoch [199/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [199/200], Loss: 0.0005


Epoch [200/200]: 100%|██████████| 90/90 [00:24<00:00,  3.68batch/s]


Epoch [200/200], Loss: 0.0007
Validation Loss: 1.8025


In [7]:
# Save val_samples
torch.save(test_samples, "test_samples.pth")

# Save val_dataset
torch.save(test_labels, "test_labels.pth")
