In [1]:
# Development setup: reload every dkst module so that source edits take
# effect without restarting the kernel, then re-import their public names.
import importlib
from dkst.utils import KST_utils, DKST_utils, set_operations, relations
from dkst import dkst_datasets, models

# Reload in a fixed order: the utils modules first, then datasets and models.
_dev_modules = (KST_utils, DKST_utils, set_operations, relations, dkst_datasets, models)
for _module in _dev_modules:
    importlib.reload(_module)

# Re-import everything from the freshly reloaded modules.
# NOTE(review): star imports also supply names used later (e.g. `os`, `torch`,
# `train`, `eval`, `collate_fn`) — consider importing those explicitly.
from dkst.utils.KST_utils import *
from dkst.utils.DKST_utils import *
from dkst.utils.set_operations import *
from dkst.utils.relations import *
from dkst.dkst_datasets import *
from dkst.models import *

import gc
import tracemalloc
from torch.utils.data import DataLoader, random_split

print("Modules reloaded and re-imported successfully.")

Modules reloaded and re-imported successfully.


In [None]:
# Build the combined dataset from config files 04.1-04.5, split it 80/20, and
# sanity-check sample/batch shapes plus one forward pass of a fresh model.
import os
import torch

SEED = 0  # fixes the train/test split so it is identical across kernel restarts

# Dataset configuration
config_path = os.path.abspath("../data/config/config_data_04.1.json")
D0 = DKSTDataset02(config_path)
for i in range(2, 6):
    config_path = os.path.abspath(f"../data/config/config_data_04.{i}.json")
    D0 += DKSTDataset02(config_path)
# NOTE: config_path is left pointing at config_data_04.5.json; the model below
# (and later cells that read config_path) is built from that last config.

# Split dataset 80/20 with a seeded generator so the split is reproducible.
train_size = int(0.8 * len(D0))
test_size = len(D0) - train_size
D_train, D_test = random_split(
    D0, [train_size, test_size], generator=torch.Generator().manual_seed(SEED)
)

sample = D_train[0]
print()
print("Length train set: ", len(D_train))
print("Shape conditionals:       ", sample[0].shape)
print("Shape input sequence:     ", sample[1].shape)
print("Shape target sequence:    ", sample[2].shape)
print("Shape input observations: ", sample[3].shape)

# Model: prefer Apple-silicon MPS, fall back to CPU when it is unavailable.
device = "mps" if torch.backends.mps.is_available() else "cpu"
model = CustomDecoderModel(config_path).to(device)

dataloader = DataLoader(D_train, batch_size=4, shuffle=True, collate_fn=collate_fn)

# Inspect a single batch.
conditionals, input_seq, target_seq, input_obs = next(iter(dataloader))
print()
# len() of the batch tuple would count its 4 fields, not the samples; use the
# leading dimension instead (assumes collate_fn stacks samples along dim 0).
print("Batch size: ", conditionals.shape[0])
print("Shape conditionals:       ", conditionals.shape)
print("Shape input sequence:     ", input_seq.shape)
print("Shape target sequence:    ", target_seq.shape)
print("Shape input observations: ", input_obs.shape)

# One forward pass just to check output shapes; no_grad avoids building an
# autograd graph (and the memory it holds) for a throwaway check.
with torch.no_grad():
    output, embedding, attention_weights = model(
        conditionals.to(device), input_seq.to(device)
    )
print()
print("Output shape: ", output.shape)
print("Embedding shape: ", embedding.shape)
print()

print("Attention weights:")
for i, attn_weight in enumerate(attention_weights):
    print(f"Layer {i} attention weights shape: {attn_weight.shape}")

In [5]:
# Training setup: a fresh model (replacing the one used for the shape check),
# data loaders, losses, optimizer and learning-rate schedule.
# Prefer Apple-silicon MPS; fall back to CPU so the notebook also runs elsewhere.
device = "mps" if torch.backends.mps.is_available() else "cpu"
model = CustomDecoderModel(config_path).to(device)
batch_size = 4

train_loader = DataLoader(D_train, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
# Held-out data needs no shuffling; a fixed order keeps evaluation batches
# identical from epoch to epoch.
eval_loader = DataLoader(D_test, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

ce_loss = CustomCELoss()
ln_loss = LengthNormLoss()

learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
clip_norm = 2  # presumably the gradient-norm clipping threshold used inside train()
n_epochs = 70
# StepLR halves the LR every 5 epochs.
# NOTE(review): over 70 epochs that is 13 halvings (final LR = 0.001 / 2**13
# ≈ 1.2e-7), so the LR is tiny for most of the later epochs.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

In [None]:
# Initial loss weights for training: the weight of the length-norm loss
# (`ln_loss`) and the penalty weight passed to train()/eval().
# (The spelling `ln_wheight` is kept because the training cell reads it.)
ln_wheight, penalty_weight = 0.001, 0

In [14]:

# Train for n_epochs, evaluating on the held-out split after every epoch.
# The loss weights are decayed on local copies so that re-running this cell
# starts again from ln_wheight / penalty_weight as configured, instead of from
# values already decayed by a previous run.
epoch_ln_weight = ln_wheight
epoch_penalty_weight = penalty_weight
on_mps = device == "mps"

for epoch in range(n_epochs):
    # Training
    train_ce_loss, train_ln_loss, train_combined_loss = train(
        model, train_loader, ce_loss, ln_loss, epoch_ln_weight, optimizer,
        penalty_weight=epoch_penalty_weight, clip_norm=clip_norm, knet=None,
        device=device, prediction_only=False,
    )
    # Evaluation. NOTE: `eval` here is presumably the dkst helper brought in by
    # the star imports; it shadows Python's builtin eval.
    eval_ce_loss, eval_ln_loss, eval_combined_loss = eval(
        model, eval_loader, ce_loss, ln_loss, epoch_ln_weight, knet=None,
        penalty_weight=epoch_penalty_weight, device=device, prediction_only=False,
    )

    # Print training and evaluation metrics
    print(f"Epoch {epoch+1}/{n_epochs}")
    print(f"Train CE Loss: {train_ce_loss:.4f}, Train LN Loss: {train_ln_loss:.4f}, Train Combined Loss: {train_combined_loss:.4f}")
    print(f"Eval CE Loss: {eval_ce_loss:.4f}, Eval LN Loss: {eval_ln_loss:.4f}, Eval Combined Loss: {eval_combined_loss:.4f}")

    # Per-epoch schedules: LR via StepLR, penalty weight linearly down to 0,
    # length-norm weight geometrically (x0.75 per epoch).
    lr_scheduler.step()
    epoch_penalty_weight = max(0, epoch_penalty_weight - 0.1)
    epoch_ln_weight *= 0.75
    print()

    # Memory diagnostics only exist on the MPS backend.
    if on_mps:
        print("Allocated memory:", torch.mps.current_allocated_memory() / (1024 ** 2), "MB | ",
              f"Driver allocated memory: {torch.mps.driver_allocated_memory() / (1024 ** 2)} MB")

    # Collect Python garbage first, then empty the MPS cache: memory from
    # tensors freed by gc is only returned to the driver by a later empty_cache().
    gc.collect()
    if on_mps:
        torch.mps.empty_cache()

Epoch 1/15
Train CE Loss: 0.0991, Train LN Loss: 0.0000, Train Combined Loss: 0.0991
Eval CE Loss: 0.0111, Eval LN Loss: 0.0000, Eval Combined Loss: 0.0111

Allocated memory: 369.6015625 MB |  Driver allocated memory: 12395.296875 MB
Epoch 2/15
Train CE Loss: 0.0999, Train LN Loss: 0.0000, Train Combined Loss: 0.0999
Eval CE Loss: 0.0119, Eval LN Loss: 0.0000, Eval Combined Loss: 0.0119

Allocated memory: 370.5537109375 MB |  Driver allocated memory: 12371.296875 MB
Epoch 3/15
Train CE Loss: 0.0995, Train LN Loss: 0.0000, Train Combined Loss: 0.0995
Eval CE Loss: 0.0115, Eval LN Loss: 0.0000, Eval Combined Loss: 0.0116

Allocated memory: 370.5537109375 MB |  Driver allocated memory: 12749.296875 MB
Epoch 4/15
Train CE Loss: 0.0996, Train LN Loss: 0.0000, Train Combined Loss: 0.0996
Eval CE Loss: 0.0118, Eval LN Loss: 0.0000, Eval Combined Loss: 0.0118

Allocated memory: 370.55419921875 MB |  Driver allocated memory: 13487.296875 MB
Epoch 5/15
Train CE Loss: 0.0989, Train LN Loss: 0.000

RuntimeError: MPS backend out of memory (MPS allocated: 361.26 MB, other allocations: 17.77 GB, max allowed: 18.13 GB). Tried to allocate 12.00 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

**TODO:** Fix the penalty loss so that it only considers states before the EOS token in the target sequence (and approaches 0 early on in training), or remove it altogether.

In [15]:
# Compare standard and MC-dropout (MCDP) generation against the target
# sequence for a few random test samples.
import torch

n_examples = 10
# One test sample per draw; built once instead of on every loop iteration.
dataloader = DataLoader(D_test, batch_size=1, shuffle=True, collate_fn=collate_fn)
# Token id filtered out of the printed target (presumably EOS/padding).
end_token = model.vocab_size - 1

for _ in range(n_examples):
    # Each new iterator reshuffles, so every draw is a random test sample.
    conditionals, input_seq, target_seq, input_obs = next(iter(dataloader))
    conditionals = conditionals.to(device)

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        seq, emb = generate_sequence(model, conditionals, device=device)
        seq_MCDP, emb_MCDP = generate_sequence_MCDP(model, conditionals, device=device)

    # Prepend 0 so the target lines up with the generated sequences, which
    # start with 0 (presumably the start token).
    print("Target sequence:    ", [t for t in [0] + target_seq.tolist()[0] if t != end_token])
    print("Generated sequence: ", seq)
    print("MCDP sequence:      ", seq_MCDP)
    print("Embedding shape: ", emb.shape)

Target sequence:     [0, 3, 7, 15, 16, 18, 20, 24, 25, 27, 29, 30, 31, 32, 33, 36, 37, 39, 40, 42, 45, 46, 48, 49, 50, 51, 52, 55, 56, 59, 62, 63, 64]
Generated sequence:  [0, 3, 7, 15, 16, 18, 20, 24, 25, 27, 29, 30, 31, 32, 33, 36, 37, 39, 40, 42, 45, 46, 48, 49, 50, 51, 52, 55, 56, 59, 62, 63, 64]
MCDP sequence:       [0, 3, 7, 14, 15, 16, 18, 20, 24, 25, 27, 29, 30, 31, 32, 33, 36, 37, 39, 40, 42, 45, 46, 48, 49, 50, 51, 52, 55, 56, 59, 62, 63, 64]
Embedding shape:  torch.Size([1, 1024])
Target sequence:     [0, 3, 4, 9, 10, 13, 14, 15, 16, 20, 23, 26, 30, 31, 32, 33, 36, 38, 39, 40, 44, 45, 50, 52, 53, 54, 56, 58, 60, 62, 63, 64]
Generated sequence:  [0, 3, 4, 6, 9, 10, 13, 14, 15, 16, 20, 23, 26, 30, 31, 32, 33, 36, 38, 39, 40, 44, 45, 50, 52, 53, 54, 56, 58, 60, 62, 63, 64]
MCDP sequence:       [0, 3, 4, 9, 10, 13, 14, 15, 16, 20, 23, 26, 30, 31, 32, 33, 36, 38, 39, 40, 44, 45, 50, 52, 53, 54, 56, 58, 60, 62, 63, 64]
Embedding shape:  torch.Size([1, 1024])
Target sequence:     [

In [16]:
# Accuracy and mean/std distance of generated sequences on the test split,
# first without MC dropout, then with an increasing number of MC-dropout runs.
num_perf_samples = 100
# Defined here explicitly (batch_size=1 test loader) rather than relying on a
# `dataloader` variable left over from another cell.
perf_loader = DataLoader(D_test, batch_size=1, shuffle=True, collate_fn=collate_fn)


def _report_performance(**mc_kwargs):
    """Run `performance` on the test loader, print its metrics and return them.

    `mc_kwargs` is forwarded to `performance` (e.g. `n_mc_samples=5`); when it
    is empty, `performance` runs with its own defaults (MC dropout disabled,
    per the recorded output).
    """
    acc, mean, std = performance(model, perf_loader, num_samples=num_perf_samples, **mc_kwargs)
    print("Accuracy:      ", acc)
    print("Mean distance: ", mean)
    print("Std:           ", std)
    return acc, mean, std


print("Performance test without MC dropout.")
acc, mean, std = _report_performance()

n_mc_samples = [2, 3, 5, 30]
for i in n_mc_samples:
    print(f"Performance test with MC dropout and {i} mcdp runs per sample.")
    acc, mean, std = _report_performance(n_mc_samples=i)
    print()



Performance test without MC dropout.


Testing performance with MCDP disabled...: 100%|██████████| 100/100 [00:13<00:00,  7.60it/s]


Accuracy:       0.74
Mean distance:  0.46
Std:            1.135077089893017
Performance test with MC dropout and 2 mcdp runs per sample.


Testing performance with MCDP enabled...: 100%|██████████| 100/100 [00:28<00:00,  3.49it/s]


Accuracy:       0.31
Mean distance:  1.88
Std:            1.8669761648183942

Performance test with MC dropout and 3 mcdp runs per sample.


Testing performance with MCDP enabled...: 100%|██████████| 100/100 [00:41<00:00,  2.43it/s]


Accuracy:       0.62
Mean distance:  0.63
Std:            1.0455142275454696

Performance test with MC dropout and 5 mcdp runs per sample.


Testing performance with MCDP enabled...: 100%|██████████| 100/100 [01:04<00:00,  1.55it/s]


Accuracy:       0.67
Mean distance:  0.55
Std:            1.0712142642814275

Performance test with MC dropout and 30 mcdp runs per sample.


Testing performance with MCDP enabled...: 100%|██████████| 100/100 [06:54<00:00,  4.15s/it]

Accuracy:       0.82
Mean distance:  0.22
Std:            0.5211525688318154




