# imports

In [1]:
!pip install wandb torchsummaryX mne transformers -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m10.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.3/8.3 MB[0m [31m54.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.6/190.6 kB[0m [31m17.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m252.8/252.8 kB[0m [31m17.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
[?25h

In [14]:
### Torch
import  torch
import  torch.nn as nn
import  torch.nn.functional as F
from    torch.optim import lr_scheduler
from    torchsummaryX import summary
from    torch.utils.data import Dataset, DataLoader, random_split
import  torchaudio
import  torchaudio.transforms as tat
from    torch.nn.utils.rnn import (pad_sequence, pack_padded_sequence,
                                   pad_packed_sequence)

### General
import  random
import  numpy as np
import  pandas as pd
import  pickle
import  scipy
import  gc
import  re
from    tqdm.auto import tqdm
import  os
import  datetime
import  time
import  wandb
import  matplotlib.pyplot as plt
import  seaborn as sns

# wav2vec2 and EEG processing
from    transformers import (AutoProcessor, AutoModelForPreTraining,
                             CLIPProcessor, CLIPModel)
import  mne

# Device
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device: ", DEVICE)

Device:  cpu


In [3]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [28]:
DATA_PATH  = '/content/gdrive/MyDrive/11785-IDLf23/Final project/0_Data/'

BRAIN_PATH = os.path.join(DATA_PATH, 'brennan_and_hale_v2')
AUDIO_PATH = os.path.join(DATA_PATH, 'brennan_and_hale_v2/audio')
PPROC_PATH = os.path.join(DATA_PATH, 'brennan_and_hale_v2/proc/timelock-preprocessing')

EEG_PATH   = os.path.join(DATA_PATH, 'eeg-segments')
EMBED_PATH = os.path.join(DATA_PATH, 'brennan_wav2vec2_embeddings')

PRETRAINED_PATH = '/content/gdrive/MyDrive/11785-IDLf23/Final project/pretrained-models'

# data [QA]

## dataset definition

In [5]:
class BrainAudioDataset(torch.utils.data.Dataset):
    """
    Hybrid Memory Efficient Dataset.
    Loads all audio embeddings in __init__(). ~1-2min.
    Loads eeg in __getitem__().
    Loading everything in __init__() crashes.
    Loading everything in __getitem__() is super slow.

    Load eeg
    Pad (add AUD of zeros) to be 62 chan
    Mask the bad chans
    """

    def __init__(self, eeg_path=EEG_PATH, embed_path=EMBED_PATH, transforms=None):
        super().__init__()

        self.eeg_root = eeg_path
        self.embed_root = embed_path
        self.transforms = transforms

        eeg_fnames = sorted(os.listdir(self.eeg_root))
        embed_fnames = sorted(os.listdir(self.embed_root))

        gc.collect()

        self.audio_embeddings = {}
        for segment_idx, fname in enumerate(tqdm(embed_fnames), start=1):
            audio_fpath = os.path.join(self.embed_root, fname)
            audio = torch.load(audio_fpath)
            audio_embed = audio.hidden_states[-1]
            audio_embed = audio_embed.squeeze(0)
            audio_embed = audio_embed.to(DEVICE)
            self.audio_embeddings[segment_idx] = audio_embed
            del audio, audio_embed
            gc.collect()

        self.eegs = {}
        for idx, fname in enumerate(tqdm(eeg_fnames)):
            # Extract subject id and segment (to be returned in __getitem__())
            subject_idx, segment_idx = self.extract_info(fname)

            # Load eeg filepath
            eeg_fpath = os.path.join(self.eeg_root, fname)
            self.eegs[idx] = (subject_idx, int(segment_idx), eeg_fpath)

        self.length = len(self.eegs)

        # montage = mne.channels.make_standard_montage("easycap-M10")

    def __len__(self):

        return self.length

    def __getitem__(self, idx):
        """
        - eeg           : (1, T_eeg, C_eeg)
        - audio_embed   : (1, T_audio, 1024)
        """
        # Get eeg
        _, segment_idx, eeg_fpath = self.eegs[idx]
        eeg = np.load(eeg_fpath)
        eeg = torch.tensor(eeg.transpose(), device=DEVICE)

        # If nchan < 62, zero pad for the audio channel (should be done in preproc)

        # Remove bad chan (should be done in prepoc)

        # Retrieve pre-loaded audio embedding
        audio_embed = self.audio_embeddings[segment_idx]
        return eeg, audio_embed, len(eeg), len(audio_embed)

    def collate_fn(self, batch):

        eeg, audio_embed, len_eeg, len_audio_embed = zip(*batch)

        # pack the mfccs and transcripts using the pad_sequence function from pytorch
        batch_eeg_pad = pad_sequence(eeg, batch_first=True)
        batch_audio_embed_pad = pad_sequence(audio_embed, batch_first=True)
        del eeg, audio_embed

        # Apply transformations
        if self.transforms is not None:
            batch_eeg_pad = self.transforms(batch_eeg_pad)

        return (batch_eeg_pad,
                batch_audio_embed_pad,
                torch.tensor(len_eeg, dtype=torch.int64),
                torch.tensor(len_audio_embed, dtype=torch.int64))

    def read_sfp(self, file_path):
        """
        Reads a BESA SFP (Surface Point) file (locations of sensors)
        """

        electrodes = {}
        with open(file_path, 'r') as file:
            for line in file:
                parts = line.strip().split()  # Split by whitespace
                if len(parts) == 4:
                    # Parse the electrode name and coordinates
                    name = parts[0]
                    x, y, z = map(float, parts[1:])  # Convert strings to floats
                    electrodes[name] = (x, y, z)

        return electrodes

    def extract_info(self, fname):
        # This pattern looks for any text (non-digits) followed by digits, a hyphen, and more digits
        match = re.match(r'([A-Za-z]+[0-9]+)-([0-9]+).npy', fname)
        if match:
            return match.group(1), match.group(2)
        else:
            return None


## dataloader definition

## train/test split

In [18]:
full_dataset = BrainAudioDataset()

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/348 [00:00<?, ?it/s]

In [25]:
TRAIN_PORTION = 0.8
VAL_PORTION   = 0.1
TEST_PORTION  = 0.1

# Splitting the dataset
torch.manual_seed(1)
train_data, val_data, test_data = random_split(full_dataset, [TRAIN_PORTION, VAL_PORTION, TEST_PORTION])

train_loader = torch.utils.data.DataLoader(
    dataset      = train_data,
    batch_size   = 8,
    shuffle      = True,
    drop_last    = False,
    num_workers  = 2,
    pin_memory   = True,
    collate_fn   = full_dataset.collate_fn
)

val_loader = torch.utils.data.DataLoader(
    dataset      = val_data,
    batch_size   = 8,
    shuffle      = False,
    drop_last    = False,
    num_workers  = 2,
    pin_memory   = True,
    collate_fn   = full_dataset.collate_fn
)

test_loader = torch.utils.data.DataLoader(
    dataset      = test_data,
    batch_size   = 8,
    shuffle      = False,
    drop_last    = False,
    num_workers  = 2,
    pin_memory   = True,
    collate_fn   = full_dataset.collate_fn
)

In [27]:
print('-'*80)
print(f'Len full data:      {len(full_dataset)}')
print(f'Len train data:     {len(full_dataset)}')
print(f'Len val data:       {len(full_dataset)}')
print(f'Len test data:      {len(full_dataset)}')
print('-'*80)
print(f'Len Train Loader:   {train_loader.__len__()}')
print(f'Len Val Loader:     {val_loader.__len__()}')
print(f'Len Test Loader:    {test_loader.__len__()}')

gc.collect()

for batch in train_loader:
    eeg, audio_embedding, l_eeg, l_audio = batch
    print(eeg.shape, audio_embedding.shape, l_eeg.shape, l_audio.shape)
    print(eeg.dtype, audio_embedding.dtype, l_eeg.dtype, l_audio.dtype)
    del eeg, audio_embedding, l_eeg, l_audio
    gc.collect()
    break

for batch in val_loader:
    eeg, audio_embedding, l_eeg, l_audio = batch
    print(eeg.shape, audio_embedding.shape, l_eeg.shape, l_audio.shape)
    print(eeg.dtype, audio_embedding.dtype, l_eeg.dtype, l_audio.dtype)
    del eeg, audio_embedding, l_eeg, l_audio
    gc.collect()
    break

for batch in test_loader:
    eeg, audio_embedding, l_eeg, l_audio = batch
    print(eeg.shape, audio_embedding.shape, l_eeg.shape, l_audio.shape)
    print(eeg.dtype, audio_embedding.dtype, l_eeg.dtype, l_audio.dtype)
    del eeg, audio_embedding, l_eeg, l_audio
    gc.collect()
    break

--------------------------------------------------------------------------------
Len full data:      348
Len train data:     348
Len val data:       348
Len test data:      348
--------------------------------------------------------------------------------
Len Train Loader:   35
Len Val Loader:     5
Len Test Loader:    5
torch.Size([8, 35022, 62]) torch.Size([8, 3499, 1024]) torch.Size([8]) torch.Size([8])
torch.float64 torch.float32 torch.int64 torch.int64
torch.Size([8, 33172, 62]) torch.Size([8, 3313, 1024]) torch.Size([8]) torch.Size([8])
torch.float64 torch.float32 torch.int64 torch.int64
torch.Size([8, 31915, 62]) torch.Size([8, 3499, 1024]) torch.Size([8]) torch.Size([8])
torch.float64 torch.float32 torch.int64 torch.int64


# models [KS]

In [None]:
class SubjectLayers(nn.Module): # learn how different are the subjects and based on that normalize the EEG
    def __init__(self, in_channels, out_channels, n_subjects):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(n_subjects, in_channels, out_channels)) # initialize weights for each subject
        self.weights.data *= 1 / in_channels**0.5 # normalize weights

    def forward(self, x, n_subjects):
        _, C, D = self.weights.shape
        subject_weights = self.weights.gather(0, n_subjects.view(-1, 1, 1).expand(-1, C, D)) # select the appropriate weights for each subject in the batch
        transformed_eeg = torch.einsum("bct,bcd->bdt", x, subject_weights) # apply the subject-specific transformations

        return transformed_eeg

class ScaledEmbedding(nn.Module): # assign a unique vector to each subject (similar to positional embedding)

    def __init__(self, n_subjects, embedding_dim, scale):

        super().__init__()
        self.embedding = nn.Embedding(n_subjects, embedding_dim)
        self.embedding.weight.data /= scale
        self.scale = scale

    def forward(self, x):
        scaled_embedding = self.embedding(x) * self.scale
        return scaled_embedding

class LayerScale(nn.Module):
    def __init__(self, channels, init = 0.1, boost = 5.):
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
        self.scale.data[:] = init / boost
        self.boost = boost

    def forward(self, x):
        return (self.boost * self.scale[:, None]) * x

class ConvSequence(nn.Module):

    def __init__(self,
                 channels = [16, 32, 64, 128],
                 kernel = 4, dilation_growth = 1, dilation_period = None,
                 stride = 2, dropout = 0.0, leakiness = 0.0,
                 groups = 1, decode = False, batch_norm = False,
                 dropout_input = 0, skip = False, scale = None,
                 rewrite = False, activation_on_last = True,
                 post_skip = False, glu = 0, glu_context = 0,
                 glu_glu = True, activation = None):

        super().__init__()
        dilation = 1
        channels = tuple(channels)
        self.skip = skip
        self.sequence = nn.ModuleList()
        self.glus = nn.ModuleList()
        if activation is None:
            activation = partial(nn.LeakyReLU, leakiness)
        Conv = nn.Conv1d if not decode else nn.ConvTranspose1d
        # build layers
        for k, (chin, chout) in enumerate(zip(channels[:-1], channels[1:])):
            layers: tp.List[nn.Module] = []
            is_last = k == len(channels) - 2

            # Set dropout for the input of the conv sequence if defined
            if k == 0 and dropout_input:
                assert 0 < dropout_input < 1
                layers.append(nn.Dropout(dropout_input))

            # conv layer
            if dilation_growth > 1:
                assert kernel % 2 != 0 # supports only odd kernel with dilation
            if dilation_period and (k % dilation_period) == 0:
                dilation = 1
            pad = kernel // 2 * dilation
            layers.append(Conv(chin, chout, kernel, stride, pad,
                               dilation=dilation, groups=groups if k > 0 else 1))
            dilation *= dilation_growth
            # non-linearity
            if activation_on_last or not is_last:
                if batch_norm:
                    layers.append(nn.BatchNorm1d(num_features=chout))
                layers.append(activation())
                if dropout:
                    layers.append(nn.Dropout(dropout))
                if rewrite:
                    layers += [nn.Conv1d(chout, chout, 1), nn.LeakyReLU(leakiness)]
            if chin == chout and skip:
                if scale is not None:
                    layers.append(LayerScale(chout, scale))
                if post_skip:
                    layers.append(Conv(chout, chout, 1, groups=chout, bias=False))

            self.sequence.append(nn.Sequential(*layers))
            if glu and (k + 1) % glu == 0:
                ch = 2 * chout if glu_glu else chout
                act = nn.GLU(dim=1) if glu_glu else activation()
                self.glus.append(
                    nn.Sequential(
                        nn.Conv1d(chout, ch, 1 + 2 * glu_context, padding=glu_context), act))
            else:
                self.glus.append(None)

    def forward(self, x):
        for module_idx, module in enumerate(self.sequence):
            old_x = x
            x = module(x)
            if self.skip and x.shape == old_x.shape:
                x = x + old_x
            glu = self.glus[module_idx]
            if glu is not None:
                x = glu(x)
        return x

class Attention(nn.Module): # scaled dot product with relative position encoding and local attention
    def __init__(self, channels, radius = 50, heads = 4):
        super().__init__()
        assert channels % heads == 0
        self.content = nn.Conv1d(channels, channels, 1)
        self.query = nn.Conv1d(channels, channels, 1)
        self.key = nn.Conv1d(channels, channels, 1)
        self.embedding = nn.Embedding(radius * 2 + 1, channels // heads)
        weight = self.embedding.weight.data
        weight[:] = weight.cumsum(0) / torch.arange(1, len(weight) + 1).float().view(-1, 1).sqrt()
        self.heads = heads
        self.radius = radius
        self.bn = nn.BatchNorm1d(channels)
        self.fc = nn.Conv1d(channels, channels, 1)
        self.scale = nn.Parameter(torch.full([channels], 0.1))

    def forward(self, x):

        def _split(y):
            return y.view(y.shape[0], self.heads, -1, y.shape[2])

        content = _split(self.content(x))
        query = _split(self.query(x))
        key = _split(self.key(x))

        batch_size, _, dim, length = content.shape

        dots = torch.einsum("bhct,bhcs->bhts", query, key) # first index `t` is query, second index `s` is key.

        steps = torch.arange(length, device=x.device)
        relative = (steps[:, None] - steps[None, :])
        embs = self.embedding.weight.gather(0, self.radius + relative.clamp_(-self.radius, self.radius).view(-1, 1).expand(-1, dim))
        embs = embs.view(length, length, -1)
        dots += 0.3 * torch.einsum("bhct,tsc->bhts", query, embs)
        dots = torch.where(
            relative.abs() <= self.radius, dots, torch.tensor(-float('inf')).to(embs))

        weights = torch.softmax(dots, dim=-1)
        out = torch.einsum("bhts,bhcs->bhct", weights, content)
        out += 0.3 * torch.einsum("bhts,tsc->bhct", weights, embs)
        out = out.reshape(batch_size, -1, length)
        out = F.relu(self.bn(self.fc(out))) * self.scale.view(1, -1, 1)
        return out

"""
# TEST 1
def test_subject_layers():
    batch_size = 2
    in_channels = 3
    out_channels = 2
    n_subjects = 1
    time_steps = 2
    eeg_data = torch.randn(batch_size, in_channels, time_steps)
    print(eeg_data)
    subjects = torch.randint(0, n_subjects, (batch_size,))
    print(subjects)
    subject_layers = SubjectLayers(in_channels, out_channels, n_subjects)
    output = subject_layers(eeg_data, subjects)
    expected_shape = (batch_size, out_channels, time_steps)
    assert output.shape == expected_shape, f"Output shape mismatch: expected {expected_shape}, got {output.shape}"
test_subject_layers()

# TEST 2
def scaled_embedding():
  n_subjects = 10
  embedding_dim = 5
  scale = 10
  scaled_embedding = ScaledEmbedding(n_subjects, embedding_dim, scale)
  subject_indices = torch.tensor([1, 2])
  embeddings = scaled_embedding(subject_indices)
  print(embeddings)
scaled_embedding()

# TEST 3
conv_sequence = ConvSequence(channels=[16, 32, 64, 128])
input_tensor = torch.randn(1, 16, 50) # (Batch Size, Channels, Length)
output = conv_sequence(input_tensor)
output.shape
#summary(conv_sequence, input_size=(16, 50))
"""

In [None]:
class EEG_Encoder(nn.Module):

    def __init__(self,

                 in_channels = 64, n_subjects = 30,

                 # subject layers
                 out_channels = 128,

                 # conv
                 conv_chanels = [16, 32, 64, 128],
                 kernel = 4, stride = 2, conv_dropout = 0,
                 batch_norm = False, dropout_input = 0,
                 leakiness = 0,

                 # lstm
                 hidden_size = 128, lstm_layers = 4, lstm_dropout = 0.1,

                 # attention
                 attention_heads = 4, subject_dim = 64, embedding_scale = 1.0):

        super().__init__()

        # subject layers (normalization across subjects)
        self.subject_layers = SubjectLayers(in_channels, out_channels, n_subjects)

        # scaled embedding (optional)
        # self.subject_embedding = ScaledEmbedding(n_subjects, subject_dim, embedding_scale)

        # convsequence
        self.convs = ConvSequence(channels=conv_chanels, kernel = kernel_size,
                                  stride = stride, dropout = conv_dropout,
                                  batch_norm = batch_norm, leakiness = leakiness,
                                  dropout_input = dropout_input)

        # lstm
        self.lstm = nn.LSTM(input_size    = out_channels,
                            hidden_size   = hidden_size//2,
                            num_layers    = lstm_layers,
                            dropout       = lstm_dropout,
                            bidirectional = True,
                            batch_first    = True)

        # attention
        self.attention = Attention(hidden_size, heads=attention_heads)

        # final linear layer
        self.finalconv1 = nn.Conv1d(hidden_size, out_channels, kernel_size = 1)

    def forward(self, eeg_inputs): # to pass as an additional paramater n_subjects=10

        print('eeg_inputs before convs', eeg_inputs.shape)
        out = self.convs(eeg_inputs)
        print('eeg_inputs after convs', out.shape)
        print("After convs:", torch.isnan(out).any())

        # lstm
        out = out.permute(2, 0, 1)
        out, _ = self.lstm(out)
        out = out.permute(1, 2, 0)
        print("After LSTM:", torch.isnan(out).any())

        # attention
        out = out + self.attention(out)
        print("After Attention:", torch.isnan(out).any())

        # final
        out = self.finalconv1(out)
        print("Final Output:", torch.isnan(out).any())

        return out

# TEST
in_channels = 64
n_subjects = 30
out_channels = 128
conv_channels = [64, 32, 64, 128]
kernel_size = 4
stride = 2
conv_dropout = 0
batch_norm = False
dropout_input = 0
leakiness = 0
hidden_size = 128
lstm_layers = 4
lstm_dropout = 0.1
attention_heads = 4
subject_dim = 64
embedding_scale = 1.0

eeg_encoder = EEG_Encoder(in_channels, n_subjects, out_channels, conv_channels, kernel_size, stride,
                          conv_dropout, batch_norm, dropout_input, leakiness, hidden_size, lstm_layers,
                          lstm_dropout, attention_heads, subject_dim, embedding_scale)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
eeg_encoder.to(device)
batch_size = 2
in_channels = 64
sequence_length = 128
x_sample = torch.rand(batch_size, in_channels, sequence_length)
summary(eeg_encoder, x_sample.to(device))

## decoder

audio -> wav2vec2 -> embedding space <- our model <- eeg

once in common embedding space, use CLIP

From CLIP --> decode to words

# losses and metrics

## clip [DS]

## wer [CM]

### general

### vocab-specific

# trainer


In [None]:
class Trainer:
    def __init__(self, model, loader, optimizer, criterion, scheduler, max_epochs= 1, run_id= 'exp'):

        self.model      = model
        self.loader     = loader
        self.optimizer  = optimizer
        self.criterion  = criterion
        self.scheduler = scheduler

        self.train_losses           = []
        self.val_losses             = []
        self.prediction_probs       = []
        self.prediction_probs_test  = []
        self.generated_texts_test   = []
        self.epochs                 = 0
        self.max_epochs             = max_epochs
        self.run_id                 = run_id


    def calculate_loss(self, out, target):
        # output: (B, T, Vocab_size) - probability distributions
        # target: (B, T)
        # Read the documentation of CrossEntropyLoss and try to understand how it takes inputs

        # Tip: If your target is of shape (B, T) it means that you have B batches with T words.
        # Tip: What is the total number of words in this batch?
        # Tip: Crossentropy calculates the loss between a label and its probability distribution.

        out     = out.view(-1, out.size(-1)) # TODO
        targets = target.view(-1)  # TODO
        loss    = self.criterion(out, targets)

        return loss


    def train(self):

        self.model.train() # set to training mode
        self.model.to(DEVICE)
        epoch_loss  = 0
        num_batches = 0

        for batch_num, (inputs, targets) in enumerate(tqdm(self.loader)):

            # TODO: Complete the loop. You should be able to complete this without any helper comments after 3 HWs
            # Tip: Mixed precision training
            # For loss calculation, use the calculate_loss function. You need to complete it before using.

            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

            self.optimizer.zero_grad()

            output, _ = self.model(inputs)
            loss = self.calculate_loss(output, targets)
            loss.backward()
            self.optimizer.step()

            loss_val = loss.item()
            epoch_loss += loss_val
            num_batches += 1

        #epoch_loss = epoch_loss / (batch_num + 1)
        epoch_loss = epoch_loss / num_batches
        self.epochs += 1
        print('[TRAIN] \tEpoch [%d/%d] \tLoss: %.4f \tLr: %.6f'
                      % (self.epochs, self.max_epochs, epoch_loss, self.optimizer.param_groups[0]['lr']))
        self.train_losses.append(epoch_loss)


    def test(self): # Don't change this function

        self.model.eval() # set to eval mode
        prediction_probs     = self.model.predict(fixtures_pred['inp']).detach().cpu().numpy() # get predictions
        self.prediction_probs.append(prediction_probs)

        generated_indexes_test   = self.model.generate(fixtures_gen_test, 10).detach().cpu().numpy() # generated predictions for 10 words

        nll                   = get_prediction_nll(prediction_probs, fixtures_pred['out'])
        generated_texts_test  = make_generation_text(fixtures_gen_test, generated_indexes_test, VOCAB)
        self.val_losses.append(nll)

        self.generated_texts_test.append(generated_texts_test)

        # generate predictions for test data
        prediction_probs_test = self.model.predict(fixtures_pred_test['inp']).detach().cpu().numpy() # get predictions
        self.prediction_probs_test.append(prediction_probs_test)

        print('[VAL] \tEpoch [%d/%d] \tLoss: %.4f'
                      % (self.epochs, self.max_epochs, nll))
        return nll


    def save(self): # Don't change this function

        model_path = os.path.join('hw4/experiments', self.run_id, 'model-{}.pkl'.format(self.epochs))
        torch.save({'state_dict': self.model.state_dict()}, model_path)
        np.save(os.path.join('hw4/experiments', self.run_id, 'prediction-probs-{}.npy'.format(self.epochs)), self.prediction_probs[-1])
        np.save(os.path.join('hw4/experiments', self.run_id, 'prediction-probs-test-{}.npy'.format(self.epochs)), self.prediction_probs_test[-1])

        with open(os.path.join('hw4/experiments', self.run_id, 'generated-texts-{}-test.txt'.format(self.epochs)), 'w') as fw:
            fw.write(self.generated_texts_test[-1])

# experiments

# evaluation

## viz