In [4]:
# Notebook setup: import the project modules, force-reload them so that edits
# made while they are under development are picked up, then re-export their
# public names into the notebook namespace.
import importlib

from dkst.utils import KST_utils, DKST_utils, set_operations, relations
from dkst import dkst_datasets, models

# Reload each module *before* the star imports below, so that those imports
# bind names from the freshly reloaded module objects.
for _module in (KST_utils, DKST_utils, set_operations, relations, dkst_datasets, models):
    importlib.reload(_module)

# Import everything from the modules
from dkst.utils.KST_utils import *
from dkst.utils.DKST_utils import *
from dkst.utils.set_operations import *
from dkst.utils.relations import *
from dkst.dkst_datasets import *
from dkst.models import *

# Explicit imports for every stdlib/torch name this notebook uses directly.
# They stay *after* the star imports so these bindings take precedence over
# any same-named re-export. `os` and `torch` were previously only reachable
# through the star imports, which breaks as soon as a module stops leaking them.
# NOTE(review): `plt` is also used further down without an explicit import;
# presumably it arrives via one of the star imports -- confirm and import
# matplotlib.pyplot explicitly here.
import gc
import os

import torch
from torch.utils.data import DataLoader, random_split, ConcatDataset

print("Modules reloaded and re-imported successfully.")
config_path = os.path.abspath("../data/config/config_data_07.json")

Modules reloaded and re-imported successfully.


In [5]:
# Dataset configuration: build the full dataset from the JSON config file.
import torch  # local import keeps this cell runnable on its own

D0 = DKSTDataset02(config_path)

# Split dataset 80/20. The explicit generator pins the random permutation so
# the same samples land in train/test on every "Restart & Run All".
SPLIT_SEED = 42
train_size = int(0.8 * len(D0))
test_size = len(D0) - train_size  # remainder, so the two sizes always sum to len(D0)
D_train, D_test = random_split(
    D0,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Peek at the first training sample to check the shape of each of its fields.
sample = D_train[0]
print()
print("Length train set: ", len(D_train))
print("Shape conditionals:       ", sample[0].shape)
print("Shape input sequence:     ", sample[1].shape)
print("Shape target sequence:    ", sample[2].shape)
print("Shape input observations: ", sample[3].shape)

KeyboardInterrupt: 

In [None]:
# Model: build it from the config file and move it to the configured device.
model = CustomDecoderModel(config_path)
device = model.config["device"]
model = model.to(device)

dataloader = DataLoader(D_train, batch_size=4, shuffle=True, collate_fn=collate_fn)

# --- Inspect one batch ------------------------------------------------------
sample_batched = next(iter(dataloader))
print()
# The batch is a tuple of tensors, so len(sample_batched) would report the
# number of fields, not the number of samples; read the leading dimension of
# the first tensor instead (assumes batch-first layout, consistent with the
# `[0]` sample indexing below).
print("Batch size: ", sample_batched[0].shape[0])
print("Shape conditionals:       ", sample_batched[0].shape)
print("Shape input sequence:     ", sample_batched[1].shape)
print("Shape target sequence:    ", sample_batched[2].shape)
print("Shape input observations: ", sample_batched[3].shape)

# Histogram to visualize conditionals of the first sample in the batch.
# A bar is green when its index occurs in that sample's target sequence.
conditionals = sample_batched[0].numpy()[0]
target_tokens = set(sample_batched[2][0].flatten().tolist())
colors = ['green' if i in target_tokens else 'red' for i in range(len(conditionals))]
colors[0] = 'green'  # index 0 is always drawn green
print("Conditionals histogram:")
plt.bar(range(len(conditionals)), conditionals.flatten(), color=colors)
plt.show()

# Histogram to visualize input observations (same colouring as above).
input_obs = sample_batched[3].numpy()[0]
print("Input observations histogram:")
plt.bar(range(len(input_obs)), input_obs.flatten(), color=colors)
plt.show()

# --- Forward pass on one (freshly shuffled) batch ---------------------------
conditionals, input_seq, target_seq, input_obs = next(iter(dataloader))
# Only the tensors the forward pass consumes are moved to the device.
conditionals = conditionals.to(device)
input_seq = input_seq.to(device)
# Call the module itself rather than .forward() so registered hooks also run.
output, embedding, attention_weights = model(conditionals, input_seq)
print()
print("Output shape: ", output.shape)
print("Embedding shape: ", embedding.shape)
print()

print("Attention weights:")
for i, attn_weight in enumerate(attention_weights):
    print(f"Layer {i} attention weights shape: {attn_weight.shape}")

In [6]:
# save_dataset(D_train, data_type="train")
# save_dataset(D_test, data_type="test")

# D_train  = load_dataset("dataset_train_Q6_50.pth")
# D_test   = load_dataset("dataset_test_Q6_50.pth")

# class FilteredDataset(Dataset):
#     def __init__(self, data):
#         self.data = data

#     def __len__(self):
#         return len(self.data)

#     def __getitem__(self, idx):
#         return self.data[idx]

# def filter_dataset(D, max_length=4):
#     D_filtered = []
#     # Get eos token id from a Subset of a ConcatDataset...
#     #eos_token_id = D.dataset.datasets[0].datasets[0].datasets[0].datasets[0].vocab_size - 2
#     eos_token_id = D.dataset.datasets[0].vocab_size - 2

#     for conditionals, input_seq, target_seq, input_obs in D:
#         if input_seq[max_length] < eos_token_id:
#             D_filtered.append((conditionals, input_seq, target_seq, input_obs))
    
#     filtered_dataset = FilteredDataset(D_filtered)

#     return filtered_dataset

# # Filter for sequences with length <= 2**5
# D_test_filtered = filter_dataset(D_test, max_length=32)
# D_train_filtered = filter_dataset(D_train, max_length=32)

# print("Length of filtered train set: ", len(D_train_filtered))
# print("Length of filtered test set: ", len(D_test_filtered))

# D_test_original = D_test
# D_train_original = D_train
# D_test = D_test_filtered
# D_train = D_train_filtered

In [6]:
# Mini-batch size shared by the data loaders that are built afterwards.
batch_size = 4

# Fresh model instance; the target device is read from the model's own config.
model = CustomDecoderModel(config_path)
device = model.config["device"]
model = model.to(device)

In [5]:

# Data loaders. Training batches are reshuffled on every pass; the evaluation
# loader keeps a fixed order so that every epoch is scored on identical
# batches and the reported eval losses are comparable between epochs.
train_loader = DataLoader(D_train, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
eval_loader = DataLoader(D_test, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

# Loss terms handed to train()/eval().
ce_loss = CustomCELoss()
ln_loss = LengthNormLoss()

# Optimisation hyper-parameters.
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
clip_norm = 1.0  # forwarded to train() as `clip_norm`
n_epochs = 50    # number of iterations of the training loop

In [6]:
# Weights forwarded to train()/eval() alongside the loss objects. The spelling
# `ln_wheight` (sic) is kept because the training loop reads that exact name.
penalty_weight = 0.0
ln_wheight = 0.0

# Step-wise learning-rate decay: gamma=0.85 with a step size of one scheduler step.
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer=optimizer,
    step_size=1,
    gamma=0.85,
)

In [7]:

# Start tracing memory allocations
#tracemalloc.start()

for epoch in range(n_epochs):
    # Training pass over `train_loader`; returns the three loss values reported below.
    train_ce_loss, train_ln_loss, train_combined_loss = train(
        model,
        train_loader,
        ce_loss,
        ln_loss,
        ln_wheight,
        optimizer,
        penalty_weight=penalty_weight,
        clip_norm=clip_norm,
        knet=None,
        device=device,
        prediction_only=False,
    )  # 0.3
    # Evaluation pass over `eval_loader` with the same loss weights.
    eval_ce_loss, eval_ln_loss, eval_combined_loss = eval(
        model,
        eval_loader,
        ce_loss,
        ln_loss,
        ln_wheight,
        knet=None,
        penalty_weight=penalty_weight,
        device=device,
        prediction_only=False,
    )

    #mean_ce, mean_ln, mean_combined = eval_with_mc(model, eval_loader, ce_loss, ln_loss, ln_wheight, knet=None, prediction_only=False, n_mc_samples=3)

    # Report this epoch's training and evaluation metrics.
    print(f"Epoch {epoch+1}/{n_epochs}")
    print(f"Train CE Loss: {train_ce_loss:.4f}, Train LN Loss: {train_ln_loss:.4f}, Train Combined Loss: {train_combined_loss:.4f}")
    print(f"Eval CE Loss: {eval_ce_loss:.4f}, Eval LN Loss: {eval_ln_loss:.4f}, Eval Combined Loss: {eval_combined_loss:.4f}")
    #print(f"MCDP CE Loss: {mean_ce:.4f}, MCDP LN Loss: {mean_ln:.4f}, MCDP Combined Loss: {mean_combined:.4f}")

    # End-of-epoch schedule: advance the LR scheduler, lower the penalty weight
    # by 0.05 (floored at 0) and scale the length-norm weight by 0.75.
    lr_scheduler.step()
    penalty_weight = max(0, penalty_weight - 0.05)
    ln_wheight = ln_wheight * 0.75
    print()

    # Empty the accelerator cache for the configured backend (only "mps" and
    # "cuda" have one here), then run the Python garbage collector.
    backend = model.config["device"]
    if backend in ("mps", "cuda"):
        getattr(torch, backend).empty_cache()
    gc.collect()

# Stop tracing memory allocations
#tracemalloc.stop()

Epoch 1/50
Train CE Loss: 0.3583, Train LN Loss: 0.0000, Train Combined Loss: 0.3583
Eval CE Loss: 0.3106, Eval LN Loss: 0.0000, Eval Combined Loss: 0.3106

Allocated memory: 344.79052734375 MB |  Driver allocated memory: 1232.78125 MB
Epoch 2/50
Train CE Loss: 0.2879, Train LN Loss: 0.0000, Train Combined Loss: 0.2879
Eval CE Loss: 0.2756, Eval LN Loss: 0.0000, Eval Combined Loss: 0.2756

Allocated memory: 353.271484375 MB |  Driver allocated memory: 1240.78125 MB
Epoch 3/50
Train CE Loss: 0.2294, Train LN Loss: 0.0000, Train Combined Loss: 0.2294
Eval CE Loss: 0.2183, Eval LN Loss: 0.0000, Eval Combined Loss: 0.2183

Allocated memory: 357.93408203125 MB |  Driver allocated memory: 1248.78125 MB
Epoch 4/50
Train CE Loss: 0.1862, Train LN Loss: 0.0000, Train Combined Loss: 0.1862
Eval CE Loss: 0.1555, Eval LN Loss: 0.0000, Eval Combined Loss: 0.1555

Allocated memory: 362.8603515625 MB |  Driver allocated memory: 1248.78125 MB
Epoch 5/50
Train CE Loss: 0.1569, Train LN Loss: 0.0000, Tr

TODO: Fix the penalty loss so that it only considers states prior to EOS in the target sequence (and approaches 0 early on in training), or remove it.

In [12]:
# Clear cache and run garbage collector
if model.config["device"] == "mps":
    torch.mps.empty_cache()
elif model.config["device"] == "cuda":
    torch.cuda.empty_cache()
gc.collect()

# Print memory usage
# print("Allocated memory:", torch.mps.current_allocated_memory() / (1024 ** 2), "MB | ",
#         f"Driver allocated memory: {torch.mps.driver_allocated_memory() / (1024 ** 2)} MB")

# Single-sample, shuffled loader over the test split. It is built once: every
# `iter(dataloader)` below starts a new shuffled pass, so each iteration still
# draws a fresh random sample without re-creating the loader.
dataloader = DataLoader(D_test, batch_size=1, shuffle=True, collate_fn=collate_fn)

for sample_idx in range(10):
    # Fetch one sample; only `conditionals` is fed to the generators, so it is
    # the only tensor that needs to live on the device.
    conditionals, input_seq, target_seq, input_obs = next(iter(dataloader))
    conditionals = conditionals.to(device)

    seq, emb = generate_sequence(model, conditionals, device=device)
    seq_MCDP, emb_MCDP = generate_sequence_MCDP(model, conditionals, device=device)
    # Prepend token 0 and drop every token equal to vocab_size-1 before printing
    # (presumably the padding id -- confirm against the dataset definition).
    print("Target sequence:    ", [t for t in [0]+target_seq.tolist()[0] if t != model.vocab_size-1])
    print("Generated sequence: ", seq)
    print("MCDP sequence:      ", seq_MCDP)
    print("Embedding shape: ", emb.shape)

Target sequence:     [0, 1, 3, 4, 9, 14, 15, 22, 24, 25, 26, 31, 34, 38, 42, 43, 44, 47, 48, 49, 50, 56, 57, 58, 60, 62, 63, 64]
Generated sequence:  [0, 1, 3, 4, 5, 9, 10, 11, 14, 15, 16, 22, 24, 25, 26, 27, 31, 34, 38, 42, 43, 44, 47, 48, 49, 50, 51, 56, 57, 58, 60, 62, 63, 64]
MCDP sequence:       [0, 1, 3, 4, 9, 14, 15, 22, 24, 25, 26, 31, 34, 38, 42, 43, 44, 47, 48, 49, 50, 56, 57, 58, 60, 62, 63, 64]
Embedding shape:  torch.Size([1, 1024])
Target sequence:     [0, 3, 5, 6, 7, 12, 14, 16, 18, 20, 23, 25, 27, 28, 30, 31, 36, 37, 39, 42, 43, 45, 46, 47, 48, 51, 53, 55, 58, 63, 64]
Generated sequence:  [0, 1, 2, 3, 5, 6, 7, 8, 9, 12, 14, 16, 18, 20, 23, 25, 27, 28, 30, 31, 32, 36, 37, 39, 40, 42, 43, 45, 46, 47, 48, 51, 53, 55, 58, 59, 63, 64]
MCDP sequence:       [0, 3, 5, 6, 7, 12, 14, 16, 18, 20, 23, 25, 27, 28, 30, 31, 36, 37, 39, 42, 43, 45, 46, 47, 48, 51, 53, 55, 58, 63, 64]
Embedding shape:  torch.Size([1, 1024])
Target sequence:     [0, 1, 2, 3, 4, 6, 10, 21, 22, 23, 24, 25,

In [10]:
def _report(acc, mean, std):
    """Print the three statistics returned by `performance`, one per line."""
    print("Accuracy:      ", acc)
    print("Mean distance: ", mean)
    print("Std:           ", std)


# Build the loader here instead of relying on whichever `dataloader` global an
# earlier cell happened to leave behind (one cell binds that name to D_train
# with batch_size=4, another to D_test with batch_size=1).
# NOTE(review): batch_size=1 on D_test mirrors the loader created in the
# sequence-generation cell -- confirm this is what `performance` expects.
perf_loader = DataLoader(D_test, batch_size=1, shuffle=True, collate_fn=collate_fn)

print("Performance test without MC dropout.")
acc, mean, std = performance(model, perf_loader, num_samples=100)
_report(acc, mean, std)

# Number of MC-dropout runs per sample to compare.
n_mc_samples = [2,3,5,30]
for i in n_mc_samples:
    print(f"Performance test with MC dropout and {i} mcdp runs per sample.")
    acc, mean, std = performance(model, perf_loader, n_mc_samples=i, num_samples=100)
    _report(acc, mean, std)
    print()



Performance test without MC dropout.


Testing performance with MCDP disabled...: 100%|██████████| 100/100 [00:18<00:00,  5.46it/s]


Accuracy:       0.0
Mean distance:  7.75
Std:            2.68467875173176
Performance test with MC dropout and 2 mcdp runs per sample.


Testing performance with MCDP enabled...: 100%|██████████| 100/100 [00:31<00:00,  3.13it/s]


Accuracy:       0.05
Mean distance:  3.52
Std:            2.104661492972207

Performance test with MC dropout and 3 mcdp runs per sample.


Testing performance with MCDP enabled...: 100%|██████████| 100/100 [00:43<00:00,  2.32it/s]


Accuracy:       0.35
Mean distance:  1.55
Std:            1.6635804759614128

Performance test with MC dropout and 5 mcdp runs per sample.


Testing performance with MCDP enabled...: 100%|██████████| 100/100 [01:06<00:00,  1.50it/s]


Accuracy:       0.52
Mean distance:  0.96
Std:            1.4347125147568764

Performance test with MC dropout and 30 mcdp runs per sample.


Testing performance with MCDP enabled...: 100%|██████████| 100/100 [06:08<00:00,  3.69s/it]

Accuracy:       0.73
Mean distance:  0.43
Std:            0.8155366331440912






In [None]:
#save_model(model, "model_Q6_01.pth", optimizer, epoch=50)

#model_re = CustomDecoderModel(config_path)
#model_re, optimizer_re, epoch_re= load_model(model_re, "model_Q6_00.pth")

#model = model_re
#model = model.to(model.config["device"])