In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [5]:
# Define model
class TheModelClass(nn.Module):
    """Small LeNet-style CNN: two conv+max-pool stages followed by three linear layers.

    ``fc1`` expects ``16 * 5 * 5`` features per sample, which matches 3-channel
    32x32 inputs: 32 -conv5-> 28 -pool2-> 14 -conv5-> 10 -pool2-> 5.
    The final layer ``fc3`` produces 10 outputs per sample.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Run a forward pass.

        Args:
            x: input batch, assumed (batch, 3, 32, 32) given fc1's input size.

        Returns:
            Tensor of shape (batch, 10).
        """
        # F (torch.nn.functional) must be imported at the top of the notebook.
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten every dim except batch. Unlike view(-1, 400), this never folds
        # spatial size into the batch dim, so a wrongly sized input fails at fc1
        # instead of silently producing a different batch size.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print model's state_dict.
# Take one state_dict() snapshot and iterate its items, rather than rebuilding
# the whole dict once per key inside the loop.
print("Model's state_dict:")
for param_tensor, tensor in model.state_dict().items():
    print(param_tensor, "\t", tensor.size())

print()
# Print optimizer's state_dict (same single-snapshot iteration).
print("Optimizer's state_dict:")
for var_name, value in optimizer.state_dict().items():
    print(var_name, "\t", value)

Model's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])

Optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]


In [6]:
# Serialize the whole `model` object (not only its state_dict) to model.pth.
torch.save(model, f='model.pth')

In [10]:
# Reload the full pickled module saved to model.pth. Since PyTorch 2.6,
# torch.load defaults to weights_only=True, which refuses to unpickle arbitrary
# nn.Module objects, so opt out explicitly.
# SECURITY: weights_only=False unpickles arbitrary objects -- trusted files only.
m2 = torch.load('model.pth', weights_only=False)
for param_tensor, tensor in m2.state_dict().items():
    print(param_tensor, "\t", tensor.size())

conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])


In [21]:
from config import *

# Name of the config object to build from; presumably provided by
# `from config import *` -- TODO confirm.
CONFIG_NAME = "vanilla"
# Namespace lookup instead of eval(): resolves the same bare name without
# evaluating arbitrary expression text.
mconfig = globals()[CONFIG_NAME]

model_fn = mconfig.model
litmodel = model_fn(mconfig)

# Single state_dict() snapshot instead of rebuilding it once per key.
for param_tensor, tensor in litmodel.state_dict().items():
    print(param_tensor, "\t", tensor.size())

torch.save(litmodel, 'litmodel.pth')

link_prediction_head.net.0.weight 	 torch.Size([256, 512])
link_prediction_head.net.0.bias 	 torch.Size([256])
link_prediction_head.net.1.weight 	 torch.Size([256])
link_prediction_head.net.1.bias 	 torch.Size([256])
link_prediction_head.net.3.weight 	 torch.Size([1, 256])
link_prediction_head.net.3.bias 	 torch.Size([1])
node_embedding.weight 	 torch.Size([19502, 128])
gene_prediction_head.net.0.weight 	 torch.Size([256, 256])
gene_prediction_head.net.0.bias 	 torch.Size([256])
gene_prediction_head.net.1.weight 	 torch.Size([256])
gene_prediction_head.net.1.bias 	 torch.Size([256])
gene_prediction_head.net.3.weight 	 torch.Size([19247, 256])
gene_prediction_head.net.3.bias 	 torch.Size([19247])
rank_prediction_head.net.0.weight 	 torch.Size([256, 256])
rank_prediction_head.net.0.bias 	 torch.Size([256])
rank_prediction_head.net.1.weight 	 torch.Size([256])
rank_prediction_head.net.1.bias 	 torch.Size([256])
rank_prediction_head.net.3.weight 	 torch.Size([255, 256])
rank_prediction_hea

In [20]:
# Reload the full pickled module saved to litmodel.pth. Since PyTorch 2.6,
# torch.load defaults to weights_only=True, which refuses to unpickle arbitrary
# nn.Module objects, so opt out explicitly.
# SECURITY: weights_only=False unpickles arbitrary objects -- trusted files only.
m2 = torch.load('litmodel.pth', weights_only=False)
for param_tensor, tensor in m2.state_dict().items():
    print(param_tensor, "\t", tensor.size())

link_prediction_head.net.0.weight 	 torch.Size([256, 512])
link_prediction_head.net.0.bias 	 torch.Size([256])
link_prediction_head.net.1.weight 	 torch.Size([256])
link_prediction_head.net.1.bias 	 torch.Size([256])
link_prediction_head.net.3.weight 	 torch.Size([1, 256])
link_prediction_head.net.3.bias 	 torch.Size([1])
node_embedding.weight 	 torch.Size([19502, 128])
gene_prediction_head.net.0.weight 	 torch.Size([256, 256])
gene_prediction_head.net.0.bias 	 torch.Size([256])
gene_prediction_head.net.1.weight 	 torch.Size([256])
gene_prediction_head.net.1.bias 	 torch.Size([256])
gene_prediction_head.net.3.weight 	 torch.Size([19247, 256])
gene_prediction_head.net.3.bias 	 torch.Size([19247])
rank_prediction_head.net.0.weight 	 torch.Size([256, 256])
rank_prediction_head.net.0.bias 	 torch.Size([256])
rank_prediction_head.net.1.weight 	 torch.Size([256])
rank_prediction_head.net.1.bias 	 torch.Size([256])
rank_prediction_head.net.3.weight 	 torch.Size([255, 256])
rank_prediction_hea

In [3]:
from config import *
import pickle

# Config used for a specific training run. Despite the .json extension, this
# file is a pickle: its contents include class objects, which JSON cannot encode.
MCONFIG_PATH = '/hpc/mydata/leo.dupire/GLM/model_out/nv8uMjIF/mconfig_used.json'

# SECURITY: pickle.load can execute arbitrary code -- only load files you trust.
with open(MCONFIG_PATH, 'rb') as file:  # pickle requires a binary stream
    data = pickle.load(file)

print(data)

{'model': <class 'flash_transformer.GDTransformer'>, 'model_config': {'num_ranks': 255, 'num_genes': 19247, 'node_embedding_dim': 128}, 'transformer_config': {'transformer_dim': {'input_dim': 256, 'feed_dim': 256, 'hidden_dims': [256, 256, 256], 'conv_dim': 256, 'out_dim': 256}, 'num_heads': 8, 'num_encoder_layers': 1, 'activation': 'gelu', 'dropout': 0.1, 'batch_first': True, 'use_pe': False, 'use_attn_mask': False, 'use_flash_attn': True}, 'data_config': {'train': {'cache_dir': '/hpc/projects/group.califano/GLM/data/cxg_cache_4096/train', 'dataset_name': 'train'}, 'val': [{'cache_dir': '/hpc/projects/group.califano/GLM/data/cxg_cache_4096/valSG', 'dataset_name': 'valSG'}, {'cache_dir': '/hpc/projects/group.califano/GLM/data/cxg_cache_4096/valHOG', 'dataset_name': 'valHOG'}], 'test': [], 'run_test': False, 'num_workers': 1, 'batch_size': 16}, 'trainer_config': {'max_epochs': 100, 'accelerator': 'gpu', 'max_time': '01:00:00:00', 'devices': 1, 'precision': 'bf16', 'num_sanity_val_steps'