# Define some paths

In [1]:
# The base path where all outputs are stored. You can create a folder or point to an existing accessible location.
YOUR_BASE_PATH='shared_fs'
CHECKPOINT_STORAGE="/run/determined/workdir/shared_fs/checkpoints"

# The following paths will be added to the base path (e.g. audio files will be stored in 'shared_fs/audio_files'). 
AUDIO_PATH='audio_schmidt' # where audio files are downloaded
SNIPPETS_PATH='schmidt' # where 10s clips are stored

# Only populate one of the following: either a YouTube playlist or a single video. Make sure the video doesn't have an age restriction, or you will have to log in
PLAYLIST_URL='' #
SINGLE_VIDEO_URL='' 

# Audio Downloading + Model Training Installations

In [2]:
!apt-get install git-lfs

Reading package lists... Done
Building dependency tree       
Reading state information... Done
git-lfs is already the newest version (2.9.2-1).
0 upgraded, 0 newly installed, 0 to remove and 19 not upgraded.


In [3]:
!git clone https://huggingface.co/fnlp/SpeechTokenizer
!pip3 install speechtokenizer

fatal: destination path 'SpeechTokenizer' already exists and is not an empty directory.
[0m

In [4]:
!pip install datasets
!pip install mamba-ssm==1.1.0
!pip install huggingface_hub
!pip install torchinfo
!pip install wandb

Collecting argparse (from buildtools->causal-conv1d>=1.1.0->mamba-ssm==1.1.0)
  Downloading argparse-1.4.0-py2.py3-none-any.whl (23 kB)
Installing collected packages: argparse
Successfully installed argparse-1.4.0
[0m

In [5]:
!pip install pytube
!pip install moviepy
!pip install soundfile

[0m

# Define some tokenization functions

In [6]:
from speechtokenizer import SpeechTokenizer
import soundfile as sf
import torchaudio
import torch
import numpy as np
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

NUM_QUANTIZERS_USED = 4
batch_size = 4

config_path = 'SpeechTokenizer/speechtokenizer_hubert_avg/config.json'
ckpt_path = 'SpeechTokenizer/speechtokenizer_hubert_avg/SpeechTokenizer.pt'
speech_tokenizer = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path).to(device)
speech_tokenizer.eval()


def normalize_waveform(waveform, sr):
    if len(waveform.shape) == 2 and waveform.shape[1] > 0:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    waveform = waveform.reshape(1, -1)
    waveform = torchaudio.functional.resample(waveform, sr, speech_tokenizer.sample_rate)
    return waveform


# Expected input: waveform of shape (B, C, V)
# The encoding returns shape (num codebooks, 1, timestep). Using it with batch size > 1 gave me a lot of errors so I didn't try it.
# After encoding the tokens, I threw away some of the later codebooks which encode less information before I flattened them as explained below.
def tokenize(waveform):
    with torch.no_grad():
        codes = speech_tokenizer.encode(waveform.unsqueeze(0).to(device))  # codes: (n_q, B, T)
    semantic_tokens = codes[:NUM_QUANTIZERS_USED, :, :].cpu()
    return flatten_tokens(semantic_tokens)

def save_waveform(filename, waveform):
  torchaudio.save(filename, waveform[0].detach().cpu(), 16000)

def decode_tokens(tokens):
  unflattened_tokens = unflatten_tokens(tokens)
  return speech_tokenizer.decode(unflattened_tokens)

def save_to_file(tok, filename):
  outputwav = decode_tokens(tok.detach().to(device))
  save_waveform(filename, outputwav)

# Transposing the timestep and code books before flattening to have it be a1, b1, c1 instead of a1, a2, a3, b1, b2, b3,
# since I'm throwing away some of the codebooks, and I also want to be able to generate based on timestep
def flatten_tokens(tokens):
    n_q, B, T = tokens.shape
    transpose_tokens = tokens.transpose(0, 2)
    return transpose_tokens.reshape(B, T * NUM_QUANTIZERS_USED)


def unflatten_tokens(tokens):
    B, L = tokens.shape
    T = L // NUM_QUANTIZERS_USED
    return tokens.reshape(T, B, NUM_QUANTIZERS_USED).transpose(0, 2)

# Download Audio

In [7]:
from pytube import Playlist, YouTube

def download_audio_from_playlist(playlist_url, output_path):
    playlist = Playlist(playlist_url)
    for video in playlist.videos:
        audio_stream = video.streams.get_audio_only()
        audio_stream.download(output_path=output_path, filename=video.title + ".mp4")

def download_audio_from_video(video_url, output_path):
    video = YouTube(video_url)
    audio_stream = video.streams.get_audio_only()
    audio_stream.download(output_path=output_path, filename=video.title + ".mp4")
    
playlist_url = PLAYLIST_URL
video_url = SINGLE_VIDEO_URL

if playlist_url:
    download_audio_from_playlist(playlist_url, f"{YOUR_BASE_PATH}/{AUDIO_PATH}")
elif video_url:
    download_audio_from_video(video_url, f"{YOUR_BASE_PATH}/{AUDIO_PATH}")

# Install audio processing tools

In [8]:
!apt-get update
!apt-get install -y ffmpeg

Hit:1 http://archive.ubuntu.com/ubuntu focal InRelease
Hit:2 https://packages.cloud.google.com/apt cloud-sdk InRelease
Get:3 http://archive.ubuntu.com/ubuntu focal-updates InRelease [114 kB]
Hit:4 http://security.ubuntu.com/ubuntu focal-security InRelease
Hit:5 http://archive.ubuntu.com/ubuntu focal-backports InRelease
Get:6 http://archive.ubuntu.com/ubuntu focal-updates/universe amd64 Packages [1461 kB]
Get:7 http://archive.ubuntu.com/ubuntu focal-updates/main amd64 Packages [3770 kB]
Get:8 http://archive.ubuntu.com/ubuntu focal-updates/restricted amd64 Packages [3294 kB]
Get:9 http://archive.ubuntu.com/ubuntu focal-updates/multiverse amd64 Packages [32.5 kB]
Fetched 8672 kB in 2s (4938 kB/s)                          
Reading package lists... Done
Reading package lists... Done
Building dependency tree       
Reading state information... Done
ffmpeg is already the newest version (7:4.2.7-0ubuntu0.1).
0 upgraded, 0 newly installed, 0 to remove and 19 not upgraded.


In [9]:
!ffmpeg -version

ffmpeg version 4.2.7-0ubuntu0.1 Copyright (c) 2000-2022 the FFmpeg developers
built with gcc 9 (Ubuntu 9.4.0-1ubuntu1~20.04.1)
configuration: --prefix=/usr --extra-version=0ubuntu0.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-avresample --disable-filter=resample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librsvg --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enabl

# Split audio into 10s snippets

In [10]:
import subprocess
import os

def extract_and_split_audio(mp4_file, output_dir, clip_length=10):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    base_name = os.path.splitext(os.path.basename(mp4_file))[0]
    output_format = "wav"
    temp_audio_file = os.path.join(output_dir, f"{base_name}_temp.{output_format}")

    # Extract the audio from the video, suppressing error messages
    with open(os.devnull, 'w') as FNULL:
        subprocess.run(["ffmpeg", "-i", mp4_file, "-q:a", "0", "-map", "a", temp_audio_file], 
                       stdout=FNULL, stderr=subprocess.STDOUT, check=True)

    # Get the duration of the audio file
    result = subprocess.run(["ffprobe", "-v", "error", "-show_entries", 
                             "format=duration", "-of", 
                             "default=noprint_wrappers=1:nokey=1", temp_audio_file], 
                            text=True, capture_output=True)
    total_duration = float(result.stdout)

    # Split the audio file into chunks
    for start in range(0, int(total_duration), clip_length):
        end = min(start + clip_length, int(total_duration))
        output_file = os.path.join(output_dir, f"{base_name}_clip_{start}_{end}.{output_format}")
        with open(os.devnull, 'w') as FNULL:
            subprocess.run(["ffmpeg", "-y", "-i", temp_audio_file, "-ss", str(start), "-to", 
                            str(end), "-c", "copy", output_file], 
                           stdout=FNULL, stderr=subprocess.STDOUT, check=True)

    # Remove the temporary audio file
    os.remove(temp_audio_file)

def process_all_files(input_dir, output_dir):
    for file in os.listdir(input_dir):
        if file.endswith('.mp4'):
            mp4_file_path = os.path.join(input_dir, file)
            extract_and_split_audio(mp4_file_path, output_dir)

In [11]:
input_directory = f"{YOUR_BASE_PATH}/{AUDIO_PATH}" # Folder where your MP4 files are located
output_directory = f"{YOUR_BASE_PATH}/{SNIPPETS_PATH}"  # Folder where you want to save the clips

process_all_files(input_directory, output_directory)

# Normalize + tokenize the waveforms

In [12]:
!pip install librosa

[0m

In [13]:
!mkdir shared_fs/testfiles

mkdir: cannot create directory ‘shared_fs/testfiles’: File exists


In [14]:
from datasets import load_dataset
import torch
import numpy as np


print("Loading Dataset")
# Select subset from 10s clips
audio_dataset = load_dataset("audiofolder", data_dir=output_directory)["train"]

print(audio_dataset)


print("Normalizing the waveforms")
audio_dataset = audio_dataset.map(
    lambda x: {
        "original_sampling_rate": x["audio"]["sampling_rate"],
        "audio_array": normalize_waveform(
            torch.tensor(x["audio"]["array"]), x["audio"]["sampling_rate"]
        ),
    },
    remove_columns=["audio"],
    # keep_in_memory=True,
    writer_batch_size=15000,
)

print(audio_dataset)


print("Making sure the dataset is in the correct format")

def standardize_audio_length_with_tolerance(example, expected_length):
    audio_array = example["audio_array"]
    if not isinstance(audio_array, torch.Tensor):
        audio_array = torch.tensor(audio_array)

    # Assuming audio_array is 2D: [channels, length]
    current_length = audio_array.shape[1]

    if current_length < expected_length:
        # Pad with zeros. The padding size should match the number of channels
        padding_length = expected_length - current_length
        padding = torch.zeros((audio_array.shape[0], padding_length), dtype=audio_array.dtype)
        audio_array = torch.cat([audio_array, padding], dim=1)
    elif current_length > expected_length:
        # Trim the excess length
        audio_array = audio_array[:, :expected_length]

    return {"audio_array": audio_array}

# Apply the function
audio_dataset = audio_dataset.map(
    lambda x: standardize_audio_length_with_tolerance(x, expected_length=160000),
    batched=False
)


print(audio_dataset)

print("Tokenizing the waveforms")
audio_dataset = audio_dataset.map(
    lambda x: {"tokens": tokenize(torch.tensor(x["audio_array"]))
               },
    remove_columns=[
        "audio_array",
    ],
    writer_batch_size=15000,
)

print(audio_dataset)

# Checking the files to see if the tokenization worked correctly.
for idx, t in enumerate(audio_dataset.select(range(0, 10))):
    save_to_file(torch.tensor(t["tokens"]).to(device), f"{YOUR_BASE_PATH}/testfiles/{idx}_test.wav")

Loading Dataset


Resolving data files:   0%|          | 0/193 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset({
    features: ['audio'],
    num_rows: 193
})
Normalizing the waveforms


Map:   0%|          | 0/193 [00:00<?, ? examples/s]

Dataset({
    features: ['original_sampling_rate', 'audio_array'],
    num_rows: 193
})
Making sure the dataset is in the correct format


Map:   0%|          | 0/193 [00:00<?, ? examples/s]

Dataset({
    features: ['original_sampling_rate', 'audio_array'],
    num_rows: 193
})
Tokenizing the waveforms


Map:   0%|          | 0/193 [00:00<?, ? examples/s]

Dataset({
    features: ['original_sampling_rate', 'tokens'],
    num_rows: 193
})


# Create train/test dataloaders

In [15]:
print(audio_dataset)

audio_dataset = audio_dataset.with_format('torch')
audio_dataset = audio_dataset.train_test_split(0.05)

# Setting up the train and test dataloader
train_dataloader = torch.utils.data.DataLoader(audio_dataset['train'], batch_size=batch_size, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(audio_dataset['test'], batch_size=batch_size, shuffle=True)

Dataset({
    features: ['original_sampling_rate', 'tokens'],
    num_rows: 193
})


# Define model + training functions

In [16]:
from tqdm import tqdm
from mamba_ssm import Mamba
import matplotlib.pyplot as plt
import numpy as np
import torchinfo
import torchaudio
from torch.utils.data import DataLoader
import os
import torch
import torch.nn as nn
from torch.nn import  functional as F
#hyperparams

epochs = 20
lr = 1e-3
block_size = 2000
device = "cuda" if torch.cuda.is_available() else "cpu"
max_iters = 10000
print_iters = 100
eval_iters = 10
eval_interval = 300
n_embed=384
n_heads = 6
n_layers = 6
dropout = 0.2
vocab_size=1024
from tqdm import tqdm

# ---------




class SelfAttentionHead(nn.Module):
  def __init__(self, head_size):
    super().__init__()
    self.keys = nn.Linear(n_embed, head_size)
    self.queries = nn.Linear(n_embed, head_size)
    self.values = nn.Linear(n_embed, head_size)
    self.head_size = head_size
    self.n_embed = n_embed
    self.register_buffer('tril', torch.tril(torch.ones((block_size,block_size))).to(device))
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    B,T,C = x.shape
    k = self.keys(x) # (B,T,C_h)
    q = self.queries(x) # (B,T,C_h)
    v = self.values(x) # (B,T,C_h)
    wei = k @ q.transpose(-1,-2) * C**(-0.5)# (B,T,T)
    wei = wei.masked_fill( self.tril[:T,:T]==0, float('-inf'))
    # wei = F.softmax(wei, dim=-1) # (B,T,T)
    wei = torch.log(torch.exp(wei)+1) # (B,T,T)
    wei = self.dropout(wei)
    out = wei @ v # (B,T,C_h)
    return out


class MultiHeadAttention(nn.Module):
  def __init__(self, n_heads, head_size) -> None:
    super().__init__()
    self.heads = nn.ModuleList([SelfAttentionHead(head_size) for _ in range(n_heads)])
    self.proj = nn.Linear(n_embed, n_embed)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    B,T,C = x.shape
    out = torch.cat([head(x) for head in self.heads], dim=-1)
    out = self.proj(out)
    out = self.dropout(out)
    return out

class FeedForward(nn.Module):
  def __init__(self, n_embed) -> None:
    super().__init__()
    self.ffn = nn.Sequential(
      nn.Linear(n_embed, 4*n_embed),
      nn.ReLU(),
      nn.Linear(4*n_embed, n_embed),
      nn.Dropout(dropout),
    )
  def forward(self, x):
    return self.ffn(x)

class Block(nn.Module):
  def __init__(self, n_embed, n_heads) -> None:
    super().__init__( )
    self.head_size = n_embed // n_heads
    # self.sa_head = MultiHeadAttention(n_heads, self.head_size)
    self.sa_head = Mamba(
      # This module uses roughly 3 * expand * d_model^2 parameters
      d_model=n_embed, # Model dimension d_model
      d_state=16,  # SSM state expansion factor
      d_conv=4,    # Local convolution width
      expand=1,    # Block expansion factor
  ).to("cuda")
    self.ffn = FeedForward(n_embed)
    self.ln1 = nn.LayerNorm(n_embed)
    self.ln2 = nn.LayerNorm(n_embed)


  def forward(self, x):
    x = x + self.sa_head(self.ln1(x))
    x = x + self.ffn(self.ln2(x))

    return x

class MambaAudioModel(nn.Module):
  def __init__(self,vocab_size):
    super().__init__()
    self.token_embedding_table = nn.Embedding(vocab_size,n_embed)
    self.position_embedding_table = nn.Embedding(block_size,n_embed)
    self.lm_head = nn.Linear(n_embed,vocab_size)
    self.ffn = FeedForward(n_embed)
    print("layers", n_layers)
    self.blocks = nn.Sequential(*[Block(n_embed,n_heads=n_heads) for _ in range(n_layers)])


  def forward(self, idx, targets=None):
    # idx = idx[:,-block_size:]
    B,T = idx.shape
    tok_emb = self.token_embedding_table(idx) # (B,T, C_e)
    pos_emb = self.position_embedding_table(torch.arange(T,device=device)) # (T, C_e)
    x = tok_emb + pos_emb # (B,T,Q, C_e)
    x = self.blocks(x) # (B,T,Q, C_e)
    logits = self.lm_head(x) # (B,T,vocab_size)
    if targets is None:
      loss = None
    else:
      B,T,C = logits.shape
      logits = logits.view(B*T,C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets)
      logits = logits.view(B,T,C)
    return logits, loss


def estimate_test_loss(model, dataset):
  model.eval()
  test_losses = []
  with torch.no_grad():
    for tokens in tqdm(dataset['test']['tokens']):
      tokens = tokens.to(device)
      x = tokens[:,:-1].contiguous()
      y = tokens[:,1:].contiguous()
      logits, loss = model(x,y )
      test_losses.append(loss)
  model.train()
  return sum(test_losses)/len(test_losses)

# Install Determined and verify

In [17]:
!pip install --upgrade determined
!pip install tensorboard

[0m

In [18]:
import determined as det
!det --version

det 0.27.0


# Train model using Detached Mode (report metrics and store checkpoints in Determined)

In [19]:
import torchinfo
from tqdm.notebook import tqdm
import random
from determined.experimental import core_v2


# initialize core context
core_v2.init(
    defaults=core_v2.DefaultConfig(
        name="schmidt",
        checkpoint_storage=CHECKPOINT_STORAGE,
    ),
    
    unmanaged=core_v2.UnmanagedConfig(
        external_experiment_id="2",
    ),
)

    
# initialize model 
model = MambaAudioModel(vocab_size).to(device)
optimizer = torch.optim.AdamW(model.parameters(),lr=lr)


# load checkpoint helper function
def load_state(checkpoint_directory, trial_id):
    checkpoint_directory = pathlib.Path(checkpoint_directory)

    with checkpoint_directory.joinpath("checkpoint.pt").open("rb") as f:
        model = torch.load(f)
    with checkpoint_directory.joinpath("state").open("r") as f:
        epochs_completed, ckpt_trial_id = [int(field) for field in f.read().split(",")]

    if ckpt_trial_id != trial_id:
        epochs_completed = 0

    return model, epochs_completed


initial_i = 0

# print model + put in train mode
print(torchinfo.summary(model))
model.train()


# training loop 
ind = initial_i

for epoch in tqdm(range(epochs)):
    for batch in tqdm(train_dataloader):
        
        checkpoint_metadata_dict = {"steps_completed": ind}

        if batch['tokens'].shape[-1] != block_size:
            continue
        batch_size = batch['tokens'].shape[0]
        tokens = batch['tokens'].to(device).reshape(batch_size,block_size)

        x = tokens[:,:-1].contiguous()
        y = tokens[:,1:].contiguous()
        logits, loss = model(x,y)

        if ind % 5 == 0:
            print(loss)
            core_v2.train.report_training_metrics(steps_completed=ind, metrics={"loss": loss.item()})

        if ind % 10 == 0:
            tl = estimate_test_loss(model, audio_dataset)
            print("testloss", tl)
            core_v2.train.report_validation_metrics(steps_completed=ind, metrics={"loss": tl.item()})


        if ind % 100 == 0:
            with core_v2.checkpoint.store_path(checkpoint_metadata_dict) as (path, storage_id):
                torch.save(model.state_dict(), path / "checkpoint.pt")
                with path.joinpath("state").open("w") as f:
                    f.write(f"{epochs},{core_v2.info.trial.trial_id}")
                    
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        
        # Prevents gradient explosion.
        torch.nn.utils.clip_grad.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        ind += 1

           
checkpoint_metadata_dict = {"steps_completed": ind}
with core_v2.checkpoint.store_path(checkpoint_metadata_dict) as (path, storage_id):
    torch.save(model.state_dict(), path / "checkpoint.pt")
    with path.joinpath("state").open("w") as f:
        f.write(f"{epoch},{core_v2.info.trial.trial_id}")
                
core_v2.close()
print(torchinfo.summary(model))

TensorFlow writer not found


layers 6
Layer (type:depth-idx)                   Param #
MambaAudioModel                          --
├─Embedding: 1-1                         393,216
├─Embedding: 1-2                         768,000
├─Linear: 1-3                            394,240
├─FeedForward: 1-4                       --
│    └─Sequential: 2-1                   --
│    │    └─Linear: 3-1                  591,360
│    │    └─ReLU: 3-2                    --
│    │    └─Linear: 3-3                  590,208
│    │    └─Dropout: 3-4                 --
├─Sequential: 1-5                        --
│    └─Block: 2-2                        --
│    │    └─Mamba: 3-5                   481,920
│    │    └─FeedForward: 3-6             1,181,568
│    │    └─LayerNorm: 3-7               768
│    │    └─LayerNorm: 3-8               768
│    └─Block: 2-3                        --
│    │    └─Mamba: 3-9                   481,920
│    │    └─FeedForward: 3-10            1,181,568
│    │    └─LayerNorm: 3-11              768
│    │    

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/46 [00:00<?, ?it/s]

tensor(7.3274, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(7.3214, device='cuda:0')
tensor(7.2452, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(6.3314, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(6.6630, device='cuda:0')
tensor(6.5034, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(6.3357, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(6.3971, device='cuda:0')
tensor(6.2270, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(6.0548, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(6.1221, device='cuda:0')
tensor(6.0560, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(5.9306, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(5.7345, device='cuda:0')
tensor(5.4604, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/46 [00:00<?, ?it/s]

tensor(5.4669, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(5.5598, device='cuda:0')
tensor(5.2266, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(5.2361, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(5.4697, device='cuda:0')
tensor(5.7999, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(5.3210, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(5.3312, device='cuda:0')
tensor(5.0460, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(4.7687, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(5.2576, device='cuda:0')
tensor(4.8634, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(4.9150, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(5.1542, device='cuda:0')


  0%|          | 0/46 [00:00<?, ?it/s]

tensor(4.2283, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(4.8406, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(5.0932, device='cuda:0')
tensor(3.5498, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(4.0672, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(5.0236, device='cuda:0')
tensor(4.0991, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(4.6609, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(4.9650, device='cuda:0')
tensor(4.1537, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(4.3385, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(4.8576, device='cuda:0')
tensor(3.7033, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/46 [00:00<?, ?it/s]

tensor(3.8732, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(4.8087, device='cuda:0')
tensor(3.4002, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(3.8893, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(4.8471, device='cuda:0')
tensor(3.8903, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(4.0163, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(4.8339, device='cuda:0')
tensor(3.7977, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(3.9014, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(4.7898, device='cuda:0')
tensor(4.1356, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(3.9349, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(4.7563, device='cuda:0')


  0%|          | 0/46 [00:00<?, ?it/s]

tensor(3.3075, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(3.3871, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(4.8133, device='cuda:0')
tensor(3.4674, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(3.3010, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(4.8398, device='cuda:0')
tensor(3.6614, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(3.0097, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(4.8685, device='cuda:0')
tensor(3.3302, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(3.4851, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(4.8366, device='cuda:0')
tensor(3.5553, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/46 [00:00<?, ?it/s]

tensor(2.4692, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(4.8628, device='cuda:0')
tensor(2.8221, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.6306, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(5.1450, device='cuda:0')
tensor(2.3788, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.7690, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(5.1891, device='cuda:0')
tensor(3.1423, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.7213, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(5.1270, device='cuda:0')
tensor(2.6221, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.8605, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(5.1706, device='cuda:0')
tensor(2.9669, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/46 [00:00<?, ?it/s]

tensor(2.0293, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(5.4514, device='cuda:0')
tensor(2.1230, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.1256, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(5.6783, device='cuda:0')
tensor(1.8620, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.1069, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(5.6810, device='cuda:0')
tensor(2.4501, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.2133, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(5.6684, device='cuda:0')
tensor(2.4362, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.2927, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(5.5295, device='cuda:0')


  0%|          | 0/46 [00:00<?, ?it/s]

tensor(1.3760, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.3112, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(6.2314, device='cuda:0')
tensor(1.3527, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.3112, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(6.3075, device='cuda:0')
tensor(1.4057, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.7306, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(6.2243, device='cuda:0')
tensor(1.6171, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.6427, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(6.2594, device='cuda:0')
tensor(2.0143, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/46 [00:00<?, ?it/s]

tensor(0.7156, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(6.8281, device='cuda:0')
tensor(1.0274, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.7264, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(6.9608, device='cuda:0')
tensor(0.8317, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.9220, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(7.1623, device='cuda:0')
tensor(0.7972, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.1321, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(7.0508, device='cuda:0')
tensor(1.1800, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.3342, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(7.0362, device='cuda:0')


  0%|          | 0/46 [00:00<?, ?it/s]

tensor(0.4312, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.5625, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(7.5136, device='cuda:0')
tensor(0.5740, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.5326, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(7.6751, device='cuda:0')
tensor(0.7098, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.5796, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(7.9220, device='cuda:0')
tensor(0.6462, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.6880, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(7.8949, device='cuda:0')
tensor(0.7127, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/46 [00:00<?, ?it/s]

tensor(0.3183, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(7.7815, device='cuda:0')
tensor(0.3297, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.3585, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(8.3666, device='cuda:0')
tensor(0.3297, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.4447, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(8.2885, device='cuda:0')
tensor(0.3876, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.4692, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(8.2495, device='cuda:0')
tensor(0.5454, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.5100, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(8.2772, device='cuda:0')
tensor(0.5674, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/46 [00:00<?, ?it/s]

tensor(0.1839, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(8.8774, device='cuda:0')
tensor(0.2070, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.3091, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(8.6266, device='cuda:0')
tensor(0.3703, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.3956, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(8.9492, device='cuda:0')
tensor(0.3610, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.3656, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(8.7607, device='cuda:0')
tensor(0.4998, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.4375, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(8.8918, device='cuda:0')


  0%|          | 0/46 [00:00<?, ?it/s]

tensor(0.2161, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.2410, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(9.0176, device='cuda:0')
tensor(0.2257, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.3318, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(9.1463, device='cuda:0')
tensor(0.2946, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.2841, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(9.0537, device='cuda:0')
tensor(0.3595, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.3786, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(9.2075, device='cuda:0')
tensor(0.4743, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/46 [00:00<?, ?it/s]

tensor(0.1428, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(9.2174, device='cuda:0')
tensor(0.2037, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.1931, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(9.4660, device='cuda:0')
tensor(0.1787, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.2188, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(9.3123, device='cuda:0')
tensor(0.2499, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.2791, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(9.4381, device='cuda:0')
tensor(0.2769, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.3486, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(9.3075, device='cuda:0')


  0%|          | 0/46 [00:00<?, ?it/s]

tensor(0.1139, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.1330, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(9.4435, device='cuda:0')
tensor(0.1277, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.1643, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(9.6835, device='cuda:0')
tensor(0.1656, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.2363, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(9.6607, device='cuda:0')
tensor(0.2714, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.2615, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(9.7269, device='cuda:0')
tensor(0.1991, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/46 [00:00<?, ?it/s]

tensor(0.1311, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(9.4810, device='cuda:0')
tensor(0.1633, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.1330, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(9.7788, device='cuda:0')
tensor(0.1731, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.1740, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(9.8275, device='cuda:0')
tensor(0.1687, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.2031, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(9.8827, device='cuda:0')
tensor(0.2195, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.2277, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(9.9340, device='cuda:0')
tensor(0.2406, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/46 [00:00<?, ?it/s]

tensor(0.1029, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(10.0759, device='cuda:0')
tensor(0.1361, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.1224, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(10.1998, device='cuda:0')
tensor(0.1850, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.1818, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(10.3068, device='cuda:0')
tensor(0.1798, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.1660, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(10.0245, device='cuda:0')
tensor(0.2077, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.2085, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(9.9488, device='cuda:0')


  0%|          | 0/46 [00:00<?, ?it/s]

tensor(0.1129, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.1234, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(10.1613, device='cuda:0')
tensor(0.1220, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.0992, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(10.3017, device='cuda:0')
tensor(0.1384, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.1224, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(10.3077, device='cuda:0')
tensor(0.1377, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.1541, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(10.2291, device='cuda:0')
tensor(0.1710, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/46 [00:00<?, ?it/s]

tensor(0.0806, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(10.1294, device='cuda:0')
tensor(0.0891, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.1286, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(10.2059, device='cuda:0')
tensor(0.1260, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.1190, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(10.3213, device='cuda:0')
tensor(0.1185, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.1161, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(10.3889, device='cuda:0')
tensor(0.1589, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.1456, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(10.4010, device='cuda:0')


  0%|          | 0/46 [00:00<?, ?it/s]

tensor(0.0596, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.0694, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(10.4208, device='cuda:0')
tensor(0.0917, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.0872, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(10.5955, device='cuda:0')
tensor(0.0999, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.1032, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(10.6544, device='cuda:0')
tensor(0.1190, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.1187, device='cuda:0', grad_fn=<NllLossBackward0>)


  0%|          | 0/10 [00:00<?, ?it/s]

testloss tensor(10.6213, device='cuda:0')
tensor(0.1008, device='cuda:0', grad_fn=<NllLossBackward0>)
Layer (type:depth-idx)                   Param #
MambaAudioModel                          --
├─Embedding: 1-1                         393,216
├─Embedding: 1-2                         768,000
├─Linear: 1-3                            394,240
├─FeedForward: 1-4                       --
│    └─Sequential: 2-1                   --
│    │    └─Linear: 3-1                  591,360
│    │    └─ReLU: 3-2                    --
│    │    └─Linear: 3-3                  590,208
│    │    └─Dropout: 3-4                 --
├─Sequential: 1-5                        --
│    └─Block: 2-2                        --
│    │    └─Mamba: 3-5                   481,920
│    │    └─FeedForward: 3-6             1,181,568
│    │    └─LayerNorm: 3-7               768
│    │    └─LayerNorm: 3-8               768
│    └─Block: 2-3                        --
│    │    └─Mamba: 3-9                   481,920
│    │    └─F

# Test out model

In [20]:
# load checkpoint - replace checkpoint path to that of the desired checkpoint
model = MambaAudioModel(vocab_size).to(device)
model.load_state_dict(torch.load(f"{YOUR_BASE_PATH}/checkpoints/ec795afe-64dd-42e0-a0e0-a36ab3aeeb86/checkpoint.pt"))

layers 6


<All keys matched successfully>

In [21]:
# Generate new audio
def unconditional_generation(model):
    idx = torch.tensor([[10,]]).to(device)
    max_new_tokens = 1999
    idx_next = []
    for i in tqdm(range(max_new_tokens)):
        idx_cond = idx[:,-block_size:]
        logits, loss = model(idx_cond)
        last_timestep = logits[:,-1,:]
        probs = F.softmax(last_timestep, dim=1)
        next_index = torch.multinomial(probs, num_samples=1)
        idx = torch.cat((idx, next_index), dim=1)

    save_to_file(idx, f"{YOUR_BASE_PATH}/unconditional_test.wav")

unconditional_generation(model)

  0%|          | 0/1999 [00:00<?, ?it/s]