In [1]:
import os
import sys
import getopt
import copy
import itertools
import json
from importlib import util

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torchvision
import torchaudio

import tensorkrowch as tk
from tensorkrowch.decompositions import tt_rss

In [2]:
# Base directory two levels above the notebook's working directory; the
# CommonVoice data, models and results folders are resolved relative to it.
cwd = os.path.join(os.getcwd(), '..', '..')

# Admissible proportions of audios with English accent: a 0.1-spaced grid
# between 0.1 and 0.9, refined near both extremes.
p_english_list = ([0.005, 0.01, 0.05]
                  + [i / 10 for i in range(1, 10)]
                  + [0.95, 0.99, 0.995])

# Target sampling rate passed to ``torchaudio.transforms.Resample``; ``crop``
# also keeps this many samples per audio.
out_rate = 1000

In [3]:
def import_file(full_name, path):
    """Load and return a Python module from an explicit file path.

    Parameters
    ----------
    full_name : str
        Name given to the created module object.
    path : str
        Path of the source file to execute as a module.

    Returns
    -------
    module
        The freshly executed module (it is not registered in ``sys.modules``).

    Raises
    ------
    ImportError
        If no module loader can be determined for ``path`` (e.g. an
        unsupported file extension).
    """
    spec = util.spec_from_file_location(full_name, path)
    # ``spec_from_file_location`` returns None instead of raising when no
    # loader matches; fail with a clear message rather than an AttributeError.
    if spec is None or spec.loader is None:
        raise ImportError(f'Cannot load module {full_name!r} from {path!r}')
    mod = util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod

In [4]:
class CustomCommonVoice(Dataset):
    """
    Class for the (imbalanced) datasets created.

    Wrapper around ``torchaudio.datasets.COMMONVOICE`` that reads the
    annotation file ``datasets/<p_english>/<idx>/<set>`` under
    ``<cwd>/CommonVoice`` and returns ``(features, label)`` pairs.

    Parameters
    ----------
    p_english : float (p_english_list)
        Proportion of audios of people with english accent in the dataset.
    idx : int [0, 9]
        Index of the annotations to be used. For each ``p_english`` there are 10
        datasets.
    set : str
        Name of the annotation file to load (e.g. ``"train_df.tsv"``,
        ``"val_df.tsv"``, ``"test_df.tsv"``).
    transform : callable, optional
        Applied to the ``(waveform, sample_rate)`` pair of each item (data
        augmentation, normalization, etc.). Its result is returned as the
        features; items whose result is None are meant to be dropped by
        ``none_collate``.

    Raises
    ------
    ValueError
        If ``p_english`` is not in ``p_english_list`` or ``idx`` is not in
        [0, 9].
    """

    def __init__(self,
                 p_english,
                 idx,
                 set="train_df.tsv",
                 transform=None):
        # ``p_english_list`` and ``cwd`` are module-level globals that are only
        # read here, so no ``global`` statement is needed.
        if (p_english not in p_english_list) or not (0 <= idx <= 9):
            raise ValueError(
                f'`p_english` can only take values within {p_english_list}, '
                f'and `idx` should be between 0 and 9')

        self.dataset = torchaudio.datasets.COMMONVOICE(
            root=os.path.join(cwd, 'CommonVoice'),
            # ``str(p_english)`` is used verbatim as the folder name (e.g. '0.5')
            tsv=os.path.join('datasets', str(p_english), str(idx), set))
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        # COMMONVOICE items are (waveform, sample_rate, metadata dict)
        waveform, sample_rate, metadata = self.dataset[index]
        x = waveform
        if self.transform:
            x = self.transform((waveform, sample_rate))
        # Label is the integer-coded 'sex' column of the annotation file
        # (presumably 0/1 for the two classes — confirm against the .tsv files).
        return x, int(metadata['sex'])
    

def resample(x):
    """Resample a ``(waveform, sample_rate)`` pair to ``out_rate`` and return the waveform."""
    waveform, in_rate = x
    resampler = torchaudio.transforms.Resample(in_rate, out_rate)
    return resampler(waveform)

def crop(x):
    """Take a centred window of ``out_rate`` samples from a waveform.

    Parameters
    ----------
    x : torch.Tensor
        Waveform indexed as ``(channels, samples)``.

    Returns
    -------
    torch.Tensor or None
        The flattened window (``out_rate`` values per channel), or None when the
        waveform has fewer than ``out_rate`` samples.
    """
    n_samples = x.size(1)
    # Check the length before slicing: a negative start would otherwise wrap
    # around and silently take samples from the end of the audio.
    if n_samples < out_rate:
        return None
    # Window of exactly ``out_rate`` samples, also when ``out_rate`` is odd
    # (the old ``[c - r//2, c + r//2)`` window was one sample short then).
    start = n_samples // 2 - out_rate // 2
    return x[:, start:start + out_rate].flatten()

def rfft(x):
    """Magnitude of the real FFT of ``x`` without its last frequency bin.

    For an even-length input the dropped bin is the Nyquist bin. None is passed
    through unchanged so that samples rejected by ``crop`` stay rejected.
    """
    if x is not None:
        spectrum = torch.fft.rfft(x)
        return spectrum[:-1].abs()
    return None

def normalize(x):
    """Scale FFT magnitudes into the open interval (0, 1).

    Values are divided by 200. The scale is presumably an empirical bound on the
    magnitudes (confirm). Results ``<= 0`` are mapped to ``1e-5`` and results
    ``>= 1`` to ``1 - 1e-5``, so that embeddings never see the exact endpoints.

    None (a sample rejected earlier in the pipeline, e.g. by ``crop``) is
    passed through so that ``none_collate`` can drop it instead of this step
    crashing on ``None / 200``.
    """
    if x is None:
        return None
    x = x / 200
    x = torch.where(x <= 0, 1e-5, x)
    x = torch.where(x >= 1, 1 - 1e-5, x)
    return x

# Audio -> features pipeline: resample to ``out_rate``, crop a centred window,
# take FFT magnitudes and squash them into (0, 1). Each step receives the
# previous step's output.
transform = torchvision.transforms.Compose(
    [torchvision.transforms.Lambda(step)
     for step in (resample, crop, rfft, normalize)]
)

def none_collate(batch):
    """Collate a batch after discarding samples whose data entry is None.

    Such samples come, e.g., from ``crop`` rejecting audios that are too short.
    """
    kept = [sample for sample in batch if sample[0] is not None]
    return torch.utils.data.dataloader.default_collate(kept)


def load_data(p_english, idx, batch_size):
    """Loads dataset performing the required transformations for train or test.

    Returns ``(train_loader, val_loader, test_loader)`` built from the
    ``train_df.tsv``, ``val_df.tsv`` and ``test_df.tsv`` annotations of the
    ``(p_english, idx)`` dataset. All of them use ``transform`` and
    ``none_collate``; only the training loader is shuffled.
    """
    splits = ("train_df.tsv", "val_df.tsv", "test_df.tsv")
    return tuple(
        DataLoader(CustomCommonVoice(p_english,
                                     idx,
                                     set=split,
                                     transform=transform),
                   batch_size=batch_size,
                   collate_fn=none_collate,
                   shuffle=(split == "train_df.tsv"))
        for split in splits
    )


def load_sketch_samples(p_english, idx, batch_size):
    """Loads sketch samples to tensorize models.

    Returns ``(test_tensorize_loader, test_unused_loader)``. Both are unshuffled
    and use ``transform`` and ``none_collate``. The tensorize loader always uses
    batches of 500 samples; ``batch_size`` only applies to the unused-samples
    loader.
    """
    def _loader(split, size):
        # One unshuffled loader over the given annotation file
        dataset = CustomCommonVoice(p_english,
                                    idx,
                                    set=split,
                                    transform=transform)
        return DataLoader(dataset,
                          batch_size=size,
                          collate_fn=none_collate,
                          shuffle=False)

    test_tensorize_loader = _loader("test_df_tensorize.tsv", 500)
    test_unused_loader = _loader("test_df_unused.tsv", batch_size)
    return test_tensorize_loader, test_unused_loader

In [5]:
def training_epoch(device, model, criterion, optimizer, train_loader,
                   n_batches=None):
    """Run one (possibly truncated) training epoch of a standard classifier.

    Parameters
    ----------
    device : torch.device or str
        Device to which each batch is moved.
    model : torch.nn.Module
        Model mapping a data batch to class scores; updated in place.
    criterion : callable
        Loss taking ``(scores, labels)``.
    optimizer : torch.optim.Optimizer
        Optimizer over ``model``'s parameters.
    train_loader : iterable
        Yields ``(data, labels)`` batches.
    n_batches : int, optional
        Maximum number of batches to process (``<= 0`` processes none); all
        batches when None.

    Returns
    -------
    torch.nn.Module
        The same ``model`` object, after training.
    """
    # islice stops after ``n_batches`` items without fetching another batch
    batches = (train_loader if n_batches is None
               else itertools.islice(train_loader, max(n_batches, 0)))

    model.train()
    for data, labels in batches:
        data = data.to(device)
        labels = labels.to(device)

        scores = model(data)
        loss = criterion(scores, labels)

        # Backward
        optimizer.zero_grad()
        loss.backward()

        # Gradient descent
        optimizer.step()

    return model


def training_epoch_tn(device, model, embedding, renormalize,
                      criterion, optimizer, train_loader, logs, n_batches=None):
    """Run one (possibly truncated) training epoch of a tensor-network model.

    Parameters
    ----------
    device : torch.device or str
        Device to which each batch is moved.
    model : torch.nn.Module
        TN model called as ``model(x, inline_input=False, inline_mats=False,
        renormalize=renormalize)``; updated in place.
    embedding : callable
        Maps a raw data batch to the model's input.
    renormalize : bool
        Forwarded to the model call.
    criterion : callable
        Loss taking ``(log_scores, labels)`` (e.g. ``nn.NLLLoss``).
    optimizer : torch.optim.Optimizer
        Optimizer over ``model``'s parameters.
    train_loader : sized iterable
        Yields ``(data, labels)`` batches; ``len`` is used for progress output.
    logs : dict
        Must contain lists ``'train_losses'`` and ``'train_accs'``, to which
        one entry per processed batch is appended.
    n_batches : int, optional
        Maximum number of batches to process (``<= 0`` processes none); all
        batches when None.

    Returns
    -------
    (torch.nn.Module, dict)
        The same ``model`` and ``logs`` objects.
    """
    # Report roughly 10 times per epoch; at least every batch for short loaders
    # (``len // 10`` alone is 0 there and made the modulo below divide by zero).
    print_each = max(1, len(train_loader) // 10)

    batches = (train_loader if n_batches is None
               else itertools.islice(train_loader, max(n_batches, 0)))

    model.train()
    for batch_idx, (data, labels) in enumerate(batches):
        data = data.to(device)
        labels = labels.to(device)

        # Forward
        scores = model(embedding(data),
                       inline_input=False,
                       inline_mats=False,
                       renormalize=renormalize)
        # Squared amplitudes divided by their L2 norm, floored away from 0 so
        # the log is finite. NOTE(review): dividing by the sum instead would
        # give normalized probabilities for NLLLoss; kept as-is (matches
        # ``test_tn``) to preserve existing results.
        scores = scores.pow(2)
        scores = scores / scores.norm(dim=1, keepdim=True)
        scores = torch.where(scores == 0, 1e-10, scores)
        scores = scores.log()

        loss = criterion(scores, labels)

        with torch.no_grad():
            _, preds = torch.max(scores, 1)
            accuracy = (preds == labels).float().mean().item()

            logs['train_losses'].append(loss.item())
            logs['train_accs'].append(accuracy)

        # Backward
        optimizer.zero_grad()
        loss.backward()

        # Gradient descent
        optimizer.step()

        if ((batch_idx + 1) % print_each == 0):
            print(f'\tBatch: {batch_idx + 1}/{len(train_loader)}, '
                  f'Last Train Loss: {loss.item():.3f}, '
                  f'Last Train Acc: {accuracy:.3f}')

    return model, logs


def test_tn(device, model, embedding, renormalize, criterion, test_loader,
            logs, n_batches=None):
    """Computes accuracy on test set."""
    running_loss = 0
    running_acc = 0
    
    if n_batches is not None:
        n_batches = min(n_batches, len(test_loader))
    else:
        n_batches = len(test_loader)
    i = 0
    
    model.eval()
    with torch.no_grad():
        for data, labels in test_loader:
            data = data.to(device)
            labels = labels.to(device)
            
            scores = model(embedding(data),
                           inline_input=False,
                           inline_mats=False,
                           renormalize=renormalize)
            scores = scores.pow(2)
            scores = scores / scores.norm(dim=1, keepdim=True)
            scores = torch.where(scores == 0, 1e-10, scores)
            scores = scores.log()
            
            loss = criterion(scores, labels)
            
            _, preds = torch.max(scores, 1)
            accuracy = (preds == labels).float().mean().item()
            running_acc += accuracy
            running_loss += loss.item()
            
            i += 1
            if i >= n_batches:
                break
    
    logs['val_losses'].append(running_loss / n_batches)
    logs['val_accs'].append(running_acc / n_batches)
    
    return logs

In [6]:
def _make_embedding(embedding_fn, embed_dim):
    """Return the feature-embedding callable named by ``embedding_fn``."""
    if embedding_fn == 'poly':
        def embedding(x):
            return tk.embeddings.poly(x, degree=embed_dim - 1)
    elif embedding_fn == 'unit':
        def embedding(x):
            return tk.embeddings.unit(x, dim=embed_dim)
    elif embedding_fn == 'basis':
        def embedding(x):
            x = tk.embeddings.discretize(x, base=embed_dim, level=1).squeeze(-1).int()
            return tk.embeddings.basis(x, dim=embed_dim).float()  # batch x n_features x dim
    elif embedding_fn == 'fourier':
        def embedding(x):
            return tk.embeddings.fourier(x, dim=embed_dim)
    else:
        raise ValueError(
            f'Unknown `embedding_fn` {embedding_fn!r}; expected one of '
            "'poly', 'unit', 'basis', 'fourier'")
    return embedding


def _rss_target(init_method, device):
    """Return ``(fn, n_samples)``: the function tensorized by TT-RSS and how many sketch samples to use."""
    softmax = nn.Softmax(dim=1)

    if init_method == 'rss_random':
        # Random target: ignores the sample values, only their number is used
        def fn(samples):
            result = torch.randn(samples.size(0), 2).to(device)
            return softmax(result).sqrt()
        return fn, 10

    if init_method not in ('rss', 'rss_pretrain'):
        raise ValueError(
            f'Unknown TT-RSS `init_method` {init_method!r}; expected one of '
            "'rss', 'rss_random', 'rss_pretrain'")

    # Both remaining methods start from the fffc_tiny model built with the
    # tuned config of the balanced (p_english=0.5) model
    aux_mod = import_file('model', os.path.join(cwd, 'models', 'fffc_tiny.py'))
    model_class = aux_mod.Model
    config_dir = os.path.join(cwd, 'results', '0_train_nns',
                              model_class.name, '0.5')
    with open(os.path.join(config_dir, 'tuned_config.json'), 'r') as f:
        config = json.load(f)
    model = model_class(config)

    if init_method == 'rss':
        # Trained weights of the balanced model (annotation index 0). Sorted so
        # the choice is deterministic if the folder holds several files.
        models_dir = os.path.join(config_dir, '0')
        state_dict_file = sorted(os.listdir(models_dir))[0]
        state_dict = torch.load(os.path.join(models_dir, state_dict_file),
                                map_location=device,
                                weights_only=False)
        model.load_state_dict(state_dict)
        model.to(device)
        n_samples = 50
    else:
        # Briefly pre-train a freshly initialized model (50 batches)
        model.to(device)
        train_loader, _, _ = load_data(0.5, 0, 32)
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=1e-2,
                                     weight_decay=1e-5)
        model = training_epoch(device=device,
                               model=model,
                               criterion=nn.CrossEntropyLoss(),
                               optimizer=optimizer,
                               train_loader=train_loader,
                               n_batches=50)
        print('* Finished pre-training')
        n_samples = 10

    # ``training_epoch`` leaves the model in train mode; evaluate it in eval
    # mode (as the 'rss' branch always did) while it is tensorized.
    model.eval()

    def fn(samples):
        return softmax(model(samples)).sqrt()

    return fn, n_samples


def train_tn(init_method='rss',
             embedding_fn='poly',
             renormalize=False,
             bond_dim=5,
             n_epochs=5):
    """Retrains best MPS models.

    Builds an ``tk.models.MPSLayer`` and trains it on the balanced
    (p_english=0.5, idx=0) dataset. Logs and final tensors are saved to
    ``results/5_initialization/<init_method>_<embedding_fn>_<renormalize>_<bond_dim>_<n_epochs>.pt``.

    Parameters
    ----------
    init_method : str
        ``'rss'`` (TT-RSS of the trained fffc_tiny model), ``'rss_random'``
        (TT-RSS of a random function), ``'rss_pretrain'`` (TT-RSS of a briefly
        pre-trained fffc_tiny model). Any other value is passed as
        ``init_method`` to ``tk.models.MPSLayer``.
    embedding_fn : str
        One of ``'poly'``, ``'unit'``, ``'basis'``, ``'fourier'``.
    renormalize : bool
        Forwarded to the MPS ``trace`` and forward calls.
    bond_dim : int
        Bond dimension (TT-RSS rank).
    n_epochs : int
        Number of training epochs.

    Raises
    ------
    ValueError
        For an unknown ``embedding_fn``, or an ``init_method`` starting with
        ``'rss'`` that is not one of the three TT-RSS variants.
    """
    print(init_method, embedding_fn, renormalize, bond_dim, n_epochs)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Tensorization hyperparameters. Inputs have ``n_features - 1`` features
    # (``out_rate // 2`` FFT bins); presumably the remaining MPS site is the
    # output node — confirm against tensorkrowch's MPSLayer.
    n_features = out_rate // 2 + 1
    embed_dim = 2
    # Validated up front so a bad name fails before any expensive work
    embedding = _make_embedding(embedding_fn, embed_dim)

    if init_method.startswith('rss'):
        fn, n_samples = _rss_target(init_method, device)

        # Shuffled sketch samples from the balanced dataset
        tensorize_loader, _ = load_sketch_samples(0.5, 0, 500)
        sketch_samples, sketch_labels = next(iter(tensorize_loader))
        perm = torch.randperm(sketch_samples.size(0))
        sketch_samples = sketch_samples[perm]
        sketch_labels = sketch_labels[perm]

        # Tensorization hyperparameters
        cum_percentage = 0.99
        domain_multiplier = 2
        domain_dim = domain_multiplier * embed_dim
        domain = torch.arange(domain_dim).float() / domain_dim
        batch_size_tensorize = 1000

        # Tensorize
        cores = tt_rss(function=fn,
                       embedding=embedding,
                       sketch_samples=sketch_samples[:n_samples],
                       labels=sketch_labels[:n_samples],
                       domain=domain,
                       domain_multiplier=domain_multiplier,
                       rank=bond_dim,
                       cum_percentage=cum_percentage,
                       batch_size=batch_size_tensorize,
                       device=device,
                       verbose=False)

        print('* Finished tensorization')

        tn_model = tk.models.MPSLayer(tensors=cores)
    else:
        tn_model = tk.models.MPSLayer(n_features=n_features,
                                      in_dim=embed_dim,
                                      out_dim=2,
                                      bond_dim=bond_dim,
                                      init_method=init_method,
                                      std=1e-5)

    tn_model.to(device)
    tn_model.trace(
        torch.zeros(1, n_features - 1, embed_dim).to(device),
        inline_input=False,
        inline_mats=False,
        renormalize=renormalize
    )

    # Load data
    batch_size = 32
    train_loader, val_loader, _ = load_data(0.5, 0, batch_size)

    criterion = nn.NLLLoss()
    optimizer = torch.optim.Adam(tn_model.parameters(),
                                 lr=1e-5,
                                 weight_decay=1e-10)

    logs = {'train_losses': [],
            'val_losses': [],
            'train_accs': [],
            'val_accs': []}

    for epoch in range(n_epochs):
        tn_model, logs = training_epoch_tn(
            device=device,
            model=tn_model,
            embedding=embedding,
            renormalize=renormalize,
            criterion=criterion,
            optimizer=optimizer,
            train_loader=train_loader,
            logs=logs,
            )

        logs = test_tn(device=device,
                       model=tn_model,
                       embedding=embedding,
                       renormalize=renormalize,
                       criterion=criterion,
                       test_loader=val_loader,
                       logs=logs,
                       )

        print(f'**Epoch: {epoch + 1}/{n_epochs}** => '
              f'Train Loss: {logs["train_losses"][-1]:.3f}, '
              f'Val Loss: {logs["val_losses"][-1]:.3f}, '
              f'Train Acc: {logs["train_accs"][-1]:.3f}, '
              f'Val Acc: {logs["val_accs"][-1]:.3f}')
        print('\t', init_method, embedding_fn, renormalize, bond_dim)

    results_dir = os.path.join(cwd, 'results', '5_initialization')
    # Create the folder if needed so the run is not lost at the very end
    os.makedirs(results_dir, exist_ok=True)

    torch.save(
        (logs, tn_model.tensors),
        os.path.join(results_dir,
                     f'{init_method}_{embedding_fn}_{renormalize}_'
                     f'{bond_dim}_{n_epochs}.pt')
    )

In [33]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Tensorization hyperparameters
n_features = out_rate // 2 + 1
embed_dim = 2
bond_dim = 5


def embedding(x):
    """Embed each feature with ``tk.embeddings.unit`` into ``embed_dim`` dimensions.

    The other embeddings ('poly', 'basis', 'fourier') are selectable through
    ``train_tn``'s ``embedding_fn`` argument.
    """
    x = tk.embeddings.unit(x, dim=embed_dim)
    return x


softmax = nn.Softmax(dim=1)


def fn(samples):
    """Random target function for TT-RSS.

    Ignores the values of ``samples`` (only their number is used) and returns
    the square roots of softmax-normalized Gaussian logits, with one row of two
    values per sample.
    """
    result = torch.randn(samples.size(0), 2).to(device)
    result = softmax(result).sqrt()
    return result


n_samples = 10

# Shuffled sketch samples from the balanced (p_english=0.5, idx=0) dataset
batch_size_loader = 500
tensorize_loader, unused_loader = load_sketch_samples(
    0.5, 0, batch_size_loader)
sketch_samples, sketch_labels = next(iter(tensorize_loader))
perm = torch.randperm(sketch_samples.size(0))
sketch_samples = sketch_samples[perm]
sketch_labels = sketch_labels[perm]

# Tensorization hyperparameters
cum_percentage = 0.99
domain_multiplier = 1

domain_dim = domain_multiplier * embed_dim
domain = torch.arange(domain_dim).float() / domain_dim

batch_size_tensorize = 1000

# Tensorize
cores = tt_rss(function=fn,
               embedding=embedding,
               sketch_samples=sketch_samples[:n_samples],
               labels=sketch_labels[:n_samples],
               domain=domain,
               domain_multiplier=domain_multiplier,
               rank=bond_dim,
               cum_percentage=cum_percentage,
               batch_size=batch_size_tensorize,
               device=device,
               verbose=False)

print('* Finished tensorization')

# MPS model initialized from the TT-RSS cores
tn_model = tk.models.MPSLayer(tensors=cores)
tn_model.to(device)

tn_model.trace(
    torch.zeros(1, n_features - 1, embed_dim).to(device),
    inline_input=False,
    inline_mats=False
)

* Finished tensorization


In [34]:
# Norm of every TT-RSS core
for core in cores:
    print(core.norm())

tensor(1.4142)
tensor(158.9334)
tensor(2.9490)
tensor(2.9675)
tensor(4.2445)
tensor(3.3539)
tensor(3.6740)
tensor(3.3155)
tensor(3.8516)
tensor(3.5697)
tensor(2.9288)
tensor(3.9666)
tensor(4.4011)
tensor(4.6749)
tensor(4.3017)
tensor(4.0500)
tensor(3.9235)
tensor(3.3399)
tensor(3.4139)
tensor(2.5714)
tensor(4.0998)
tensor(3.8676)
tensor(3.5947)
tensor(3.1668)
tensor(4.9416)
tensor(3.6019)
tensor(2.9243)
tensor(4.7549)
tensor(3.7685)
tensor(3.8942)
tensor(3.1023)
tensor(4.8769)
tensor(3.8876)
tensor(3.3290)
tensor(4.0792)
tensor(2.4973)
tensor(5.6227)
tensor(2.8766)
tensor(3.7875)
tensor(3.6727)
tensor(3.1738)
tensor(4.6243)
tensor(3.2961)
tensor(3.4429)
tensor(4.5232)
tensor(3.9021)
tensor(3.5805)
tensor(3.6043)
tensor(3.4333)
tensor(4.5634)
tensor(3.2855)
tensor(3.5899)
tensor(2.8971)
tensor(3.1691)
tensor(2.9947)
tensor(4.3825)
tensor(3.7128)
tensor(3.5671)
tensor(3.2088)
tensor(3.4340)
tensor(3.8380)
tensor(2.8607)
tensor(4.0594)
tensor(2.9803)
tensor(3.3131)
tensor(3.2838)
tensor(3

In [35]:
# Mean and standard deviation of the entries of every TT-RSS core
for core in cores:
    print(core.mean(), core.std())

tensor(0.4883) tensor(0.5905)
tensor(0.9876) tensor(41.0238)
tensor(0.0615) tensor(0.4681)
tensor(0.0183) tensor(0.4235)
tensor(0.0877) tensor(0.5998)
tensor(0.0834) tensor(0.4717)
tensor(-0.0128) tensor(0.5247)
tensor(0.1388) tensor(0.4524)
tensor(0.0174) tensor(0.5500)
tensor(-0.0105) tensor(0.5099)
tensor(-0.0338) tensor(0.4170)
tensor(0.0242) tensor(0.5661)
tensor(0.0331) tensor(0.6278)
tensor(0.0132) tensor(0.6677)
tensor(0.0387) tensor(0.6133)
tensor(0.0053) tensor(0.5785)
tensor(0.0779) tensor(0.5550)
tensor(0.0696) tensor(0.4719)
tensor(0.0116) tensor(0.4876)
tensor(0.0240) tensor(0.3665)
tensor(-0.0503) tensor(0.5835)
tensor(0.0223) tensor(0.5521)
tensor(-0.0681) tensor(0.5089)
tensor(-0.0245) tensor(0.4517)
tensor(0.1339) tensor(0.6929)
tensor(-0.0236) tensor(0.5140)
tensor(0.0464) tensor(0.4151)
tensor(0.0332) tensor(0.6784)
tensor(-0.0336) tensor(0.5373)
tensor(0.0235) tensor(0.5558)
tensor(-0.0671) tensor(0.4380)
tensor(0.0111) tensor(0.6966)
tensor(0.0175) tensor(0.5551)


In [36]:
# Output statistics of the MPS on standard-normal inputs. NOTE(review): these
# inputs are unbounded, unlike the (0, 1) features produced by ``normalize``.
result = tn_model(embedding(torch.randn(1000, n_features - 1).to(device)),
                  inline_input=False, inline_mats=False)

(result[:, 0].mean(), result[:, 0].std(),
 result[:, 1].mean(), result[:, 1].std())

(tensor(-2.0958e+08, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(4.0423e+09, device='cuda:0', grad_fn=<StdBackward0>),
 tensor(29129898., device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(2.1941e+09, device='cuda:0', grad_fn=<StdBackward0>))

In [37]:
# Output statistics of the MPS on the (shuffled) sketch samples
result = tn_model(embedding(sketch_samples.to(device)),
                  inline_input=False, inline_mats=False)

(result[:, 0].mean(), result[:, 0].std(),
 result[:, 1].mean(), result[:, 1].std())

(tensor(0.0875, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(0.1879, device='cuda:0', grad_fn=<StdBackward0>),
 tensor(0.0936, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(0.2012, device='cuda:0', grad_fn=<StdBackward0>))

In [38]:
# Per-sample MPS outputs on the sketch samples (computed in the previous cell)
result

tensor([[0.0639, 0.0684],
        [0.3045, 0.3233],
        [0.1356, 0.1460],
        [0.0324, 0.0347],
        [0.0183, 0.0195],
        [0.0300, 0.0321],
        [0.0474, 0.0509],
        [0.1243, 0.1335],
        [0.0514, 0.0551],
        [0.2680, 0.2858],
        [0.0092, 0.0099],
        [0.1012, 0.1085],
        [0.2637, 0.2818],
        [0.1548, 0.1664],
        [0.0238, 0.0256],
        [0.0120, 0.0130],
        [0.0713, 0.0765],
        [0.0648, 0.0691],
        [0.0576, 0.0620],
        [0.1269, 0.1354],
        [0.0608, 0.0654],
        [0.1099, 0.1181],
        [0.0588, 0.0629],
        [0.0330, 0.0355],
        [0.0195, 0.0209],
        [0.6417, 0.6886],
        [0.0259, 0.0277],
        [0.0171, 0.0182],
        [0.7481, 0.8053],
        [0.3266, 0.3410],
        [0.0221, 0.0237],
        [0.1138, 0.1213],
        [0.2060, 0.2205],
        [0.0105, 0.0113],
        [0.0154, 0.0166],
        [0.0599, 0.0643],
        [0.0187, 0.0201],
        [0.0128, 0.0138],
        [0.0

In [15]:
# Fresh MPS with 'randn_eye' initialization, for comparison with the TT-RSS cores
tn_model = tk.models.MPSLayer(
    n_features=n_features,
    in_dim=embed_dim,
    out_dim=2,
    bond_dim=bond_dim,
    init_method='randn_eye',
    std=1e-5,
)
tn_model.to(device)

tn_model.trace(torch.zeros(1, n_features - 1, embed_dim).to(device),
               inline_input=False, inline_mats=False)

In [16]:
# Norm, mean and std of each MPS tensor, detached from autograd for printing
for mps_tensor in tn_model.tensors:
    detached = mps_tensor.detach()
    print(detached.norm(), detached.mean(), detached.std())

tensor(1.0000, device='cuda:0') tensor(0.1000, device='cuda:0') tensor(0.3162, device='cuda:0')
tensor(2.2361, device='cuda:0') tensor(0.1000, device='cuda:0') tensor(0.3030, device='cuda:0')
tensor(2.2361, device='cuda:0') tensor(0.1000, device='cuda:0') tensor(0.3030, device='cuda:0')
tensor(2.2361, device='cuda:0') tensor(0.1000, device='cuda:0') tensor(0.3030, device='cuda:0')
tensor(2.2361, device='cuda:0') tensor(0.1000, device='cuda:0') tensor(0.3030, device='cuda:0')
tensor(2.2361, device='cuda:0') tensor(0.1000, device='cuda:0') tensor(0.3030, device='cuda:0')
tensor(2.2361, device='cuda:0') tensor(0.1000, device='cuda:0') tensor(0.3030, device='cuda:0')
tensor(2.2361, device='cuda:0') tensor(0.1000, device='cuda:0') tensor(0.3030, device='cuda:0')
tensor(2.2361, device='cuda:0') tensor(0.1000, device='cuda:0') tensor(0.3030, device='cuda:0')
tensor(2.2361, device='cuda:0') tensor(0.1000, device='cuda:0') tensor(0.3030, device='cuda:0')
tensor(2.2361, device='cuda:0') tensor(0

In [None]:

# Names this cell relies on but that the cells above only define as
# ``train_tn`` arguments. The values mirror ``train_tn``'s defaults and the
# setup above ('randn_eye' MPS, 'unit' embedding).
n_epochs = 5
renormalize = False  # NOTE(review): should match how ``tn_model.trace`` was called — confirm
init_method = 'randn_eye'
embedding_fn = 'unit'

# Load data
batch_size = 32
train_loader, val_loader, _ = load_data(0.5, 0, batch_size)

criterion = nn.NLLLoss()
optimizer = torch.optim.Adam(tn_model.parameters(),
                             lr=1e-5,
                             weight_decay=1e-10)

logs = {'train_losses': [],
        'val_losses': [],
        'train_accs': [],
        'val_accs': []}

for epoch in range(n_epochs):
    tn_model, logs = training_epoch_tn(
        device=device,
        model=tn_model,
        embedding=embedding,
        renormalize=renormalize,
        criterion=criterion,
        optimizer=optimizer,
        train_loader=train_loader,
        logs=logs,
        )

    logs = test_tn(device=device,
                   model=tn_model,
                   embedding=embedding,
                   renormalize=renormalize,
                   criterion=criterion,
                   test_loader=val_loader,
                   logs=logs,
                   )

    print(f'**Epoch: {epoch + 1}/{n_epochs}** => '
          f'Train Loss: {logs["train_losses"][-1]:.3f}, '
          f'Val Loss: {logs["val_losses"][-1]:.3f}, '
          f'Train Acc: {logs["train_accs"][-1]:.3f}, '
          f'Val Acc: {logs["val_accs"][-1]:.3f}')
    print('\t', init_method, embedding_fn, renormalize, bond_dim)

results_dir = os.path.join(cwd, 'results', '5_initialization')