In [1]:
import functools
import enum
import os

from BH.data_loader import *
from BH.generate_data import *
from training_info import *
# from Model_e import Model_e,Direction,Reduction
from Train import train,print_accuracies
from torch_geometric.loader import DataLoader


os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device ="cuda:0"
use_pretrained_weights = True  #@param{type:"boolean"}
hold_graphs_in_memory = False  #@param{type:"boolean"}

gb = 1024**3
total_memory = psutil.virtual_memory().total / gb
if total_memory < 20 and hold_graphs_in_memory:
    raise RuntimeError(f"It is unlikely your machine (with {total_memory}Gb) will have enough memory to complete the colab's execution!")

print("Loading input data...")
full_dataset, train_dataset, test_dataset = load_input_data(DIR_PATH)

  from .autonotebook import tqdm as notebook_tqdm


Loading input data...
Generating data from the directory /Data/Ptab/n=5_2row


In [2]:
from torch_geometric.data import Data
from torch.utils.data import Dataset
from torch_geometric.loader import DataLoader
import torch
class CustomDataset(Dataset):
    def __init__(self, input_data):
        self.features = input_data.features
        self.labels = input_data.labels
        self.rows = input_data.rows
        self.cols = input_data.columns
        self.edge_types = input_data.edge_types

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        edge_index = torch.tensor([self.rows[idx], self.cols[idx]], dtype=torch.long)
        return Data(x=torch.from_numpy(self.features[idx]).float(), edge_index=edge_index, 
             edge_types = torch.tensor(self.edge_types[idx][:, np.newaxis], dtype=torch.float),
             y=torch.from_numpy(np.array(self.labels[idx])))

In [3]:
node_dim=64
edge_dim=8
graph_deg=5
batch_size=32

test_dataset = CustomDataset(test_dataset)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)

train_dataset = CustomDataset(train_dataset)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

In [4]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv,GCNConv


class GCN_single(torch.nn.Module):
    def __init__(self,num_edge_types,depth):
        super().__init__()
        self.num_edge_types= num_edge_types
        self.depth=depth
        self.node_linear = torch.nn.Linear(1,node_dim)
        self.edge_linear = torch.nn.Linear(1,edge_dim)
        self.conv1 = GCNConv(node_dim, node_dim)
        self.conv2 = GCNConv(node_dim, node_dim)
        self.conv3 = GCNConv(node_dim, node_dim)
        self.conv4 = GCNConv(node_dim, node_dim)
        
        self.out1 = torch.nn.Linear(node_dim,node_dim)
        self.out2 = torch.nn.Linear(node_dim,1)
        self.initialize_parameters()

    def initialize_parameters(self):
        for module in self.modules():
            if isinstance(module, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
        

    def forward(self, data):
        x, edge_index, edge_types = data.x, data.edge_index, data.edge_types
        x = self.node_linear(x)        
        for i in range(self.depth):
            mask1 = (edge_types == 1)
            edge_index1 = edge_index[:, mask1.squeeze()]
            x1 = self.conv1(x, edge_index1)
            x1 = F.relu(x1)
            
            mask2 = (edge_types == 2)
            edge_index2 = edge_index[:, mask2.squeeze()]
            x2 = self.conv2(x, edge_index2)
            x2 = F.relu(x2)
            
            mask3 = (edge_types == 3)
            edge_index3 = edge_index[:, mask3.squeeze()]
            x3 = self.conv3(x, edge_index3)
            x3 = F.relu(x3)
            
            mask4 = (edge_types == 4)
            edge_index4 = edge_index[:, mask4.squeeze()]
            x4 = self.conv4(x, edge_index4)
            x4 = F.relu(x4)
            
            x = x1+x2+x3+x4
            x = F.relu(x)
        xx = torch.reshape(x,(-1,graph_deg,node_dim))
        xxx,_ = torch.max(xx,dim=1)
        xxx = self.out1(xxx)
        xxx = self.out2(xxx)
        
        return xxx
#         return F.log_softmax(xxx, dim=1)
    
class GCN_multi(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.num_edge_types=4
        self.depth =5
        self.GCN_single = GCN_single(self.num_edge_types, self.depth)        

    def forward(self, data,T=1):
        batch = data.batch
        batch = batch[::5]
        x = self.GCN_single(data)
        x=torch.sigmoid(x/T)
        unique_batches = torch.unique(batch)
        
#         sums_tensor = torch.zeros(len(unique_batches), requires_grad=True)

#         # Loop over unique batches and add elements of 'x' within each batch directly to sums_tensor
#         for i, ub in enumerate(unique_batches):
#             sums_tensor[i] = x[batch == ub].sum()

            
            # Loop over unique batches and sum elements of 'x' within each batch
        sum_list = [x[batch == ub].sum() for ub in unique_batches]

        # Stack list of sums to create a tensor
        sums_tensor = torch.stack(sum_list)
#         print(sums_tensor)
        
#         unique_batches = torch.unique(batch)
#         sums = []
#         for ub in unique_batches:
#             sums.append(x[batch == ub].sum())

#         # Convert list of sums to tensor
#         sums_tensor = torch.tensor(sums)
        
        return sums_tensor

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = pastGCN().to(device)
model = GCN_multi().to(device)
# data = batch.to(device)
# torch.nn.init.xavier_normal(model)
loss_function = torch.nn.CrossEntropyLoss()
loss_function = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

In [6]:
test_dataset[2].y

  edge_index = torch.tensor([self.rows[idx], self.cols[idx]], dtype=torch.long)


tensor(0)

In [7]:
test_dataset[2].edge_index

tensor([[ 0,  0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  3,  3,  4,  5,  5,  5,
          5,  5,  6,  6,  6,  6,  7,  7,  7,  8,  8,  9, 10, 10, 10, 10, 10, 11,
         11, 11, 11, 12, 12, 12, 13, 13, 14],
        [ 0,  1,  2,  3,  4,  1,  2,  3,  4,  2,  3,  4,  3,  4,  4,  5,  6,  7,
          8,  9,  6,  7,  8,  9,  7,  8,  9,  8,  9,  9, 10, 11, 12, 13, 14, 11,
         12, 13, 14, 12, 13, 14, 13, 14, 14]])

In [8]:
for data in test_loader:
    break
print(data)

DataBatch(x=[240, 1], edge_index=[2, 720], y=[32], edge_types=[720, 1], batch=[240], ptr=[33])


In [9]:
data.batch[::5]

tensor([ 0,  1,  2,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 13, 13, 14,
        15, 16, 17, 17, 18, 19, 19, 19, 20, 20, 21, 22, 23, 23, 23, 24, 25, 26,
        26, 27, 28, 28, 28, 29, 29, 29, 30, 31, 31, 31])

In [10]:
num_epochs=1000
for epoch in range(num_epochs):
    # Training phase
    model.train()
    for batch in train_loader:
        batch.to(device)
        optimizer.zero_grad()
        out = model(batch)
        
        batch.y = batch.y.float()
        loss = loss_function(out, batch.y)
        loss.backward()
        optimizer.step()
    print(loss)
    
    # Evaluation phase
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in train_loader:
            batch.to(device)
            outputs = model(batch)
#             _,predicted = torch.max(outputs.data, 1)
            predicted = outputs
            total += batch.y.size(0)
#             correct += (predicted == batch.y).sum().item()
            correct += ((predicted - batch.y)**2<0.1).sum().item()

    # Compute accuracy
    accuracy = correct / total

    print("Epoch [{}/{}], Accuracy: {:.2%}".format(epoch + 1, num_epochs, accuracy))

tensor(1.1250, grad_fn=<MseLossBackward0>)
Epoch [1/1000], Accuracy: 31.25%
tensor(0.5000, grad_fn=<MseLossBackward0>)
Epoch [2/1000], Accuracy: 31.25%
tensor(1., grad_fn=<MseLossBackward0>)
Epoch [3/1000], Accuracy: 31.25%
tensor(0.7500, grad_fn=<MseLossBackward0>)
Epoch [4/1000], Accuracy: 31.25%
tensor(1.2500, grad_fn=<MseLossBackward0>)
Epoch [5/1000], Accuracy: 31.25%
tensor(0.6250, grad_fn=<MseLossBackward0>)
Epoch [6/1000], Accuracy: 31.25%
tensor(0.7500, grad_fn=<MseLossBackward0>)
Epoch [7/1000], Accuracy: 31.25%
tensor(0.5625, grad_fn=<MseLossBackward0>)
Epoch [8/1000], Accuracy: 31.25%
tensor(0.6250, grad_fn=<MseLossBackward0>)
Epoch [9/1000], Accuracy: 31.25%
tensor(0.8750, grad_fn=<MseLossBackward0>)
Epoch [10/1000], Accuracy: 31.25%
tensor(0.8125, grad_fn=<MseLossBackward0>)
Epoch [11/1000], Accuracy: 31.25%
tensor(1.1875, grad_fn=<MseLossBackward0>)
Epoch [12/1000], Accuracy: 31.25%
tensor(0.8125, grad_fn=<MseLossBackward0>)
Epoch [13/1000], Accuracy: 31.25%
tensor(0.500

tensor(0.3931, grad_fn=<MseLossBackward0>)
Epoch [108/1000], Accuracy: 14.35%
tensor(0.3863, grad_fn=<MseLossBackward0>)
Epoch [109/1000], Accuracy: 43.98%
tensor(0.4390, grad_fn=<MseLossBackward0>)
Epoch [110/1000], Accuracy: 16.67%
tensor(0.2638, grad_fn=<MseLossBackward0>)
Epoch [111/1000], Accuracy: 16.90%
tensor(0.2656, grad_fn=<MseLossBackward0>)
Epoch [112/1000], Accuracy: 41.90%
tensor(0.2108, grad_fn=<MseLossBackward0>)
Epoch [113/1000], Accuracy: 19.68%
tensor(0.2623, grad_fn=<MseLossBackward0>)
Epoch [114/1000], Accuracy: 43.29%
tensor(0.2614, grad_fn=<MseLossBackward0>)
Epoch [115/1000], Accuracy: 58.10%
tensor(0.2799, grad_fn=<MseLossBackward0>)
Epoch [116/1000], Accuracy: 55.79%
tensor(0.2738, grad_fn=<MseLossBackward0>)
Epoch [117/1000], Accuracy: 46.99%
tensor(0.1497, grad_fn=<MseLossBackward0>)
Epoch [118/1000], Accuracy: 50.00%
tensor(0.4406, grad_fn=<MseLossBackward0>)
Epoch [119/1000], Accuracy: 47.69%
tensor(0.1779, grad_fn=<MseLossBackward0>)
Epoch [120/1000], Acc

tensor(0.0705, grad_fn=<MseLossBackward0>)
Epoch [214/1000], Accuracy: 75.23%
tensor(0.1273, grad_fn=<MseLossBackward0>)
Epoch [215/1000], Accuracy: 86.11%
tensor(0.0243, grad_fn=<MseLossBackward0>)
Epoch [216/1000], Accuracy: 84.03%
tensor(0.0946, grad_fn=<MseLossBackward0>)
Epoch [217/1000], Accuracy: 71.99%
tensor(0.0831, grad_fn=<MseLossBackward0>)
Epoch [218/1000], Accuracy: 87.27%
tensor(0.0913, grad_fn=<MseLossBackward0>)
Epoch [219/1000], Accuracy: 86.11%
tensor(0.0772, grad_fn=<MseLossBackward0>)
Epoch [220/1000], Accuracy: 85.42%
tensor(0.0286, grad_fn=<MseLossBackward0>)
Epoch [221/1000], Accuracy: 84.72%
tensor(0.0818, grad_fn=<MseLossBackward0>)
Epoch [222/1000], Accuracy: 83.56%
tensor(0.0332, grad_fn=<MseLossBackward0>)
Epoch [223/1000], Accuracy: 84.26%
tensor(0.0460, grad_fn=<MseLossBackward0>)
Epoch [224/1000], Accuracy: 84.95%
tensor(0.0465, grad_fn=<MseLossBackward0>)
Epoch [225/1000], Accuracy: 88.66%
tensor(0.0207, grad_fn=<MseLossBackward0>)
Epoch [226/1000], Acc

tensor(0.0078, grad_fn=<MseLossBackward0>)
Epoch [320/1000], Accuracy: 99.07%
tensor(0.0264, grad_fn=<MseLossBackward0>)
Epoch [321/1000], Accuracy: 98.84%
tensor(0.0045, grad_fn=<MseLossBackward0>)
Epoch [322/1000], Accuracy: 98.38%
tensor(0.0024, grad_fn=<MseLossBackward0>)
Epoch [323/1000], Accuracy: 96.99%
tensor(0.0059, grad_fn=<MseLossBackward0>)
Epoch [324/1000], Accuracy: 91.67%
tensor(0.0041, grad_fn=<MseLossBackward0>)
Epoch [325/1000], Accuracy: 96.99%
tensor(0.0011, grad_fn=<MseLossBackward0>)
Epoch [326/1000], Accuracy: 96.76%
tensor(0.0072, grad_fn=<MseLossBackward0>)
Epoch [327/1000], Accuracy: 98.15%
tensor(0.0036, grad_fn=<MseLossBackward0>)
Epoch [328/1000], Accuracy: 98.15%
tensor(0.0166, grad_fn=<MseLossBackward0>)
Epoch [329/1000], Accuracy: 97.45%
tensor(0.0221, grad_fn=<MseLossBackward0>)
Epoch [330/1000], Accuracy: 98.15%
tensor(0.0479, grad_fn=<MseLossBackward0>)
Epoch [331/1000], Accuracy: 85.42%
tensor(0.0313, grad_fn=<MseLossBackward0>)
Epoch [332/1000], Acc

tensor(0.0003, grad_fn=<MseLossBackward0>)
Epoch [425/1000], Accuracy: 96.99%
tensor(0.0081, grad_fn=<MseLossBackward0>)
Epoch [426/1000], Accuracy: 98.84%
tensor(0.0411, grad_fn=<MseLossBackward0>)
Epoch [427/1000], Accuracy: 99.31%
tensor(0.0043, grad_fn=<MseLossBackward0>)
Epoch [428/1000], Accuracy: 99.07%
tensor(0.0119, grad_fn=<MseLossBackward0>)
Epoch [429/1000], Accuracy: 99.77%
tensor(0.0027, grad_fn=<MseLossBackward0>)
Epoch [430/1000], Accuracy: 99.31%
tensor(0.0037, grad_fn=<MseLossBackward0>)
Epoch [431/1000], Accuracy: 99.31%
tensor(0.0071, grad_fn=<MseLossBackward0>)
Epoch [432/1000], Accuracy: 99.54%
tensor(0.0010, grad_fn=<MseLossBackward0>)
Epoch [433/1000], Accuracy: 100.00%
tensor(0.0011, grad_fn=<MseLossBackward0>)
Epoch [434/1000], Accuracy: 100.00%
tensor(0.0109, grad_fn=<MseLossBackward0>)
Epoch [435/1000], Accuracy: 99.77%
tensor(0.0068, grad_fn=<MseLossBackward0>)
Epoch [436/1000], Accuracy: 99.07%
tensor(0.0113, grad_fn=<MseLossBackward0>)
Epoch [437/1000], A

tensor(0.0004, grad_fn=<MseLossBackward0>)
Epoch [530/1000], Accuracy: 97.69%
tensor(0.0044, grad_fn=<MseLossBackward0>)
Epoch [531/1000], Accuracy: 93.98%
tensor(0.0027, grad_fn=<MseLossBackward0>)
Epoch [532/1000], Accuracy: 98.84%
tensor(0.0006, grad_fn=<MseLossBackward0>)
Epoch [533/1000], Accuracy: 98.84%
tensor(0.0262, grad_fn=<MseLossBackward0>)
Epoch [534/1000], Accuracy: 94.44%
tensor(0.0099, grad_fn=<MseLossBackward0>)
Epoch [535/1000], Accuracy: 99.07%
tensor(0.0043, grad_fn=<MseLossBackward0>)
Epoch [536/1000], Accuracy: 97.92%
tensor(0.0086, grad_fn=<MseLossBackward0>)
Epoch [537/1000], Accuracy: 96.99%
tensor(0.0353, grad_fn=<MseLossBackward0>)
Epoch [538/1000], Accuracy: 98.38%
tensor(0.0003, grad_fn=<MseLossBackward0>)
Epoch [539/1000], Accuracy: 99.54%
tensor(0.0028, grad_fn=<MseLossBackward0>)
Epoch [540/1000], Accuracy: 99.77%
tensor(0.0003, grad_fn=<MseLossBackward0>)
Epoch [541/1000], Accuracy: 100.00%
tensor(0.0035, grad_fn=<MseLossBackward0>)
Epoch [542/1000], Ac

tensor(0.0002, grad_fn=<MseLossBackward0>)
Epoch [635/1000], Accuracy: 100.00%
tensor(0.0008, grad_fn=<MseLossBackward0>)
Epoch [636/1000], Accuracy: 100.00%
tensor(0.0004, grad_fn=<MseLossBackward0>)
Epoch [637/1000], Accuracy: 99.31%
tensor(0.0048, grad_fn=<MseLossBackward0>)
Epoch [638/1000], Accuracy: 100.00%
tensor(0.0023, grad_fn=<MseLossBackward0>)
Epoch [639/1000], Accuracy: 97.69%
tensor(0.0009, grad_fn=<MseLossBackward0>)
Epoch [640/1000], Accuracy: 94.44%
tensor(4.7598e-05, grad_fn=<MseLossBackward0>)
Epoch [641/1000], Accuracy: 100.00%
tensor(0.0359, grad_fn=<MseLossBackward0>)
Epoch [642/1000], Accuracy: 99.77%
tensor(0.0085, grad_fn=<MseLossBackward0>)
Epoch [643/1000], Accuracy: 99.77%
tensor(0.0051, grad_fn=<MseLossBackward0>)
Epoch [644/1000], Accuracy: 99.31%
tensor(4.3523e-05, grad_fn=<MseLossBackward0>)
Epoch [645/1000], Accuracy: 100.00%
tensor(6.5679e-05, grad_fn=<MseLossBackward0>)
Epoch [646/1000], Accuracy: 100.00%
tensor(3.2006e-05, grad_fn=<MseLossBackward0>)

tensor(0.0007, grad_fn=<MseLossBackward0>)
Epoch [739/1000], Accuracy: 99.54%
tensor(0.0065, grad_fn=<MseLossBackward0>)
Epoch [740/1000], Accuracy: 96.53%
tensor(0.0068, grad_fn=<MseLossBackward0>)
Epoch [741/1000], Accuracy: 96.30%
tensor(0.0649, grad_fn=<MseLossBackward0>)
Epoch [742/1000], Accuracy: 95.14%
tensor(0.0114, grad_fn=<MseLossBackward0>)
Epoch [743/1000], Accuracy: 98.15%
tensor(0.0359, grad_fn=<MseLossBackward0>)
Epoch [744/1000], Accuracy: 98.61%
tensor(0.0069, grad_fn=<MseLossBackward0>)
Epoch [745/1000], Accuracy: 97.92%
tensor(0.0077, grad_fn=<MseLossBackward0>)
Epoch [746/1000], Accuracy: 97.22%
tensor(0.0013, grad_fn=<MseLossBackward0>)
Epoch [747/1000], Accuracy: 99.54%
tensor(0.0005, grad_fn=<MseLossBackward0>)
Epoch [748/1000], Accuracy: 96.06%
tensor(0.0008, grad_fn=<MseLossBackward0>)
Epoch [749/1000], Accuracy: 98.84%
tensor(0.0003, grad_fn=<MseLossBackward0>)
Epoch [750/1000], Accuracy: 98.38%
tensor(0.0003, grad_fn=<MseLossBackward0>)
Epoch [751/1000], Acc

Epoch [842/1000], Accuracy: 100.00%
tensor(0.0001, grad_fn=<MseLossBackward0>)
Epoch [843/1000], Accuracy: 100.00%
tensor(1.1695e-05, grad_fn=<MseLossBackward0>)
Epoch [844/1000], Accuracy: 100.00%
tensor(0.0001, grad_fn=<MseLossBackward0>)
Epoch [845/1000], Accuracy: 100.00%
tensor(0.0004, grad_fn=<MseLossBackward0>)
Epoch [846/1000], Accuracy: 100.00%
tensor(0.0007, grad_fn=<MseLossBackward0>)
Epoch [847/1000], Accuracy: 100.00%
tensor(0.0001, grad_fn=<MseLossBackward0>)
Epoch [848/1000], Accuracy: 100.00%
tensor(0.0003, grad_fn=<MseLossBackward0>)
Epoch [849/1000], Accuracy: 100.00%
tensor(0.0001, grad_fn=<MseLossBackward0>)
Epoch [850/1000], Accuracy: 100.00%
tensor(0.0056, grad_fn=<MseLossBackward0>)
Epoch [851/1000], Accuracy: 99.31%
tensor(0.0002, grad_fn=<MseLossBackward0>)
Epoch [852/1000], Accuracy: 100.00%
tensor(4.1407e-05, grad_fn=<MseLossBackward0>)
Epoch [853/1000], Accuracy: 100.00%
tensor(0.0002, grad_fn=<MseLossBackward0>)
Epoch [854/1000], Accuracy: 100.00%
tensor(0.

tensor(0.0001, grad_fn=<MseLossBackward0>)
Epoch [946/1000], Accuracy: 100.00%
tensor(0.0001, grad_fn=<MseLossBackward0>)
Epoch [947/1000], Accuracy: 100.00%
tensor(2.4037e-05, grad_fn=<MseLossBackward0>)
Epoch [948/1000], Accuracy: 100.00%
tensor(0.0003, grad_fn=<MseLossBackward0>)
Epoch [949/1000], Accuracy: 100.00%
tensor(0.0002, grad_fn=<MseLossBackward0>)
Epoch [950/1000], Accuracy: 100.00%
tensor(9.5292e-05, grad_fn=<MseLossBackward0>)
Epoch [951/1000], Accuracy: 99.77%
tensor(0.0023, grad_fn=<MseLossBackward0>)
Epoch [952/1000], Accuracy: 100.00%
tensor(0.0001, grad_fn=<MseLossBackward0>)
Epoch [953/1000], Accuracy: 100.00%
tensor(3.6821e-05, grad_fn=<MseLossBackward0>)
Epoch [954/1000], Accuracy: 100.00%
tensor(0.0001, grad_fn=<MseLossBackward0>)
Epoch [955/1000], Accuracy: 100.00%
tensor(2.6677e-05, grad_fn=<MseLossBackward0>)
Epoch [956/1000], Accuracy: 100.00%
tensor(0.0004, grad_fn=<MseLossBackward0>)
Epoch [957/1000], Accuracy: 100.00%
tensor(0.0002, grad_fn=<MseLossBackwa

In [11]:
loss

tensor(0.0006, grad_fn=<MseLossBackward0>)

In [12]:
out.shape

torch.Size([16])

In [13]:
test_dataset[0]

Data(x=[5, 1], edge_index=[2, 15], y=1, edge_types=[15, 1])

In [14]:
batch.x[0]

tensor([1.])

In [15]:
batch.y

tensor([1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1])

In [16]:
batch

DataBatch(x=[85, 1], edge_index=[2, 255], y=[16], edge_types=[255, 1], batch=[85], ptr=[17])

In [17]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in train_loader:
        batch.to(device)
        outputs = model(batch)
        _, predicted = torch.max(outputs.data, 1)
        total += batch.y.size(0)
        correct += (predicted == batch.y).sum().item()
        
# Compute accuracy
accuracy = correct / total

print("Epoch [{}/{}], Accuracy: {:.2%}".format(epoch + 1, num_epochs, accuracy))

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

In [None]:
from datetime import datetime
current_date = datetime.now().strftime('%Y%m%d')

# Define the path to save the model parameters
# You might want to modify this to a directory of your choice
path = f"./model_parameters_{current_date}.pth"

# Save the model parameters
torch.save(model.state_dict(), path)