In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import os
import pathlib
import glob
import soundfile as sf
import time
from datetime import datetime

In [2]:
def get_masker_helper(B, T, mask_prob=0.065, mask_length=10, device=None):
    if device is None:
        device = torch.device('cpu')
    
    if T <= 0 or B <= 0:
        raise ValueError("T and B must be positive integers.")
    
    mask = torch.zeros(B, T, dtype=torch.bool, device=device)
    mask_length = max(1, min(mask_length, T))

    target_num_mask = int(mask_prob * T)
    num_spans = max(1, target_num_mask // mask_length)

    for b in range(B):
        starts = torch.randint(0, T, (num_spans,), device=device)
        
        for start in starts:
            end = min(start + mask_length, T)
            mask[b, start:end] = True
    return mask

In [3]:
class LibriSpeechDataset(Dataset):
    def __init__(self, directory_path):
        self.directory = pathlib.Path(directory_path)
        self.samples = []

        for speaker in os.listdir(directory_path):
            speaker_path = os.path.join(directory_path, speaker)
            if not os.path.isdir(speaker_path):
                continue

            for book in os.listdir(speaker_path):
                chapter_path = os.path.join(speaker_path, book)
                if not os.path.isdir(chapter_path):
                    continue

                files = os.listdir(chapter_path)
                transcript_file = glob.glob(os.path.join(chapter_path, "*.trans.txt"))
                if not transcript_file:
                    transcript_file = glob.glob(os.path.join(chapter_path, "*.txt"))
                if not transcript_file:
                    continue

                transcript_file = transcript_file[0]
                transcript_dict = {}
                with open(transcript_file, 'r') as f:
                    for line in f:
                        id, text = line.strip().split(' ', 1)
                        transcript_dict[id] = text

                for file in files:
                    if file.lower().endswith('.flac'):
                        base_name = file.replace('.flac', '')
                        audio_path = os.path.join(chapter_path, file)

                        transcript = transcript_dict.get(base_name)
                        if transcript is None:
                            continue

                        self.samples.append({
                            'id': base_name,
                            'audio_path': audio_path,
                            'transcript': transcript
                        })

    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        sample = self.samples[idx]
        data, sr = sf.read(sample['audio_path'], dtype='float32')

        if data.ndim == 1:
            waveform = torch.from_numpy(data).unsqueeze(0)
        else:
            waveform = torch.from_numpy(data.T)
        
        w = waveform.squeeze(0)
        mean = w.mean()
        std = w.std() if w.std() > 1e-6 else 1.0
        w = (w - mean) / std

        waveform = w.unsqueeze(0)
        return {
            'id': sample['id'],
            'waveform': waveform,
            'sr': sr,
            'transcript': sample['transcript'],
            'path': sample['audio_path']
        }


def collate_fn(batch):
    waveforms = [item['waveform'].squeeze(0) for item in batch]  
    lengths = torch.tensor([len(waveform) for waveform in waveforms], dtype=torch.long)
    
    max_len = lengths.max().item()
    padded_audios = torch.zeros(len(waveforms), max_len)
    
    for i, waveform in enumerate(waveforms):
        padded_audios[i, :len(waveform)] = waveform
    
    return {
        'audio': padded_audios,
        'lengths': lengths
    }

In [4]:
class Encoder(torch.nn.Module):
    def __init__(self, 
                 seed=None, 
                 in_channels=1, 
                 conv_channels=[512, 512, 512, 512, 512, 512, 512],
                 kernel_sizes=[10, 3, 3, 3, 3, 2, 2],
                 strides=[5, 2, 2, 2, 2, 2, 2],
                 dropout=0.1,
                 sample_rate=16000):
        super().__init__()
        if seed is not None:
            torch.manual_seed(seed)

        self.in_channels = in_channels
        self.conv_channels = list(conv_channels)
        self.kernel_sizes = list(kernel_sizes)
        self.strides = list(strides)
        self.dropout_prob = dropout
        self.sample_rate = sample_rate

        assert(len(self.conv_channels) == len(self.kernel_sizes) == len(self.strides))
        self.paddings = [k // 2 for k in self.kernel_sizes]

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        self.activations = nn.ModuleList()
        prev_channels = self.in_channels

        for out_c, k, s, p in zip(self.conv_channels, self.kernel_sizes, self.strides, self.paddings):
            conv = nn.Conv1d(in_channels=prev_channels, out_channels=out_c, kernel_size=k, stride=s, padding=p, bias=False)
            self.conv_layers.append(conv)
            self.norm_layers.append(nn.LayerNorm(out_c, eps=1e-5))
            self.activations.append(nn.GELU())
            self.dropouts.append(nn.Dropout(p=self.dropout_prob))
            prev_channels = out_c
        self._init_weights()    

    def _init_weights(self):
        for conv in self.conv_layers:
            nn.init.kaiming_normal_(conv.weight, nonlinearity='relu')

        for ln in self.norm_layers:
            nn.init.ones_(ln.weight)
            nn.init.zeros_(ln.bias)

    def compute_output_lengths(self, input_lengths):
        if isinstance(input_lengths, int):
            L = torch.tensor([input_lengths], dtype=torch.long)
        elif isinstance(input_lengths, (list, tuple)):
            L = torch.tensor(list(input_lengths), dtype=torch.long)
        elif torch.is_tensor(input_lengths):
            L = input_lengths.to(torch.long).clone()
            if L.dim() == 0:
                L = L.unsqueeze(0)
        else:
            raise ValueError("input_lengths must be an int, list, tuple, or torch.Tensor")
        
        for k, s, p in zip(self.kernel_sizes, self.strides, self.paddings):
            L = (L + 2 * int(p) - int(k)) // int(s) + 1
            L = torch.clamp(L, min=1)
        return L.long()
    
    def forward(self, x, input_lengths=None):
        """
        input_lengths are in number of samples
        output is in frames
        padding mask returns true for padded positions
        """
        if (x.dim() == 2):
            x = x.unsqueeze(1)
        elif (x.dim() == 3):
            pass
        else:
            raise ValueError("Input tensor must be (B, T) or (B, 1, T)")
        
        for i, (conv, norm, activ, drop) in enumerate(zip(self.conv_layers, self.norm_layers, self.activations, self.dropouts)):
            x = conv(x)                          # (B, C_out, T_out)
            x = x.transpose(1,2)                 # (B, T_out, C_out)
            x = norm(x)                          
            x = activ(x)
            x = drop(x)
            x = x.transpose(1, 2)                # (B, C_out, T_out)

        features = x.transpose(1, 2).contiguous()        # (B, T_encoder, C)

        if input_lengths is None:
            return features
        
        output_lengths = self.compute_output_lengths(input_lengths)
        output_lengths = output_lengths.to(features.device)

        B, T_e, _ = features.shape
        if (output_lengths > T_e).any():
            output_lengths = torch.clamp(output_lengths, max=T_e)

        pos = torch.arange(T_e, device=features.device).unsqueeze(0).expand(B, T_e)
        padded_mask = pos >= output_lengths.unsqueeze(1)

        assert padded_mask.shape == (B, T_e)
        assert output_lengths.dtype == torch.long

        return features, output_lengths, padded_mask

In [5]:
class Masker(nn.Module):
    def __init__(self, embed_dim, mask_prob=0.065, mask_length=10):
        super().__init__()
        self.embed_dim = embed_dim
        self.mask_prob = float(mask_prob)
        self.mask_length = int(mask_length)
        self.mask_emb = nn.Parameter(torch.randn(1, 1, embed_dim))

    def get_mask(self, B, T, device=None):
        if device is None:
            device = torch.device('cpu')
        return get_masker_helper(B, T, self.mask_prob, self.mask_length, device)
    
    def apply_mask(self, Z, mask):
        assert Z.dim() == 3, "Expected Z shape: [B, T, C]"
        B, T, C = Z.shape
        if mask.shape != (B, T):
            raise ValueError(f"Mask shape {mask.shape} does not match Z shape {Z.shape}")

        if not mask.any():
            b_idx = torch.empty(0, dtype=torch.long, device=Z.device)
            t_idx = torch.empty(0, dtype=torch.long, device=Z.device)
            return Z, b_idx, t_idx, mask

        Z_masked = Z.clone()
        Z_masked[mask] = self.mask_emb
        b_idx, t_idx = torch.where(mask)
        return Z_masked, b_idx, t_idx, mask
    
    def forward(self, Z):
        assert Z.dim() == 3, "Expected Z shape: [B, T, C]"
        B, T, C = Z.shape
        device = Z.device

        if not self.training:
            b_idx = torch.empty(0, dtype=torch.long, device=Z.device)
            t_idx = torch.empty(0, dtype=torch.long, device=Z.device)
            mask = torch.zeros(B, T, dtype=torch.bool, device=Z.device)
            return Z, b_idx, t_idx, mask

        mask = self.get_mask(B, T, device)
        return self.apply_mask(Z, mask)

In [6]:
class Quantizer(nn.Module):
    def __init__(self, input_dim, G=2, V=320, temp_init=2.0, temp_min=0.5, temp_decay=0.999995):
        super(Quantizer, self).__init__()
        assert input_dim % G == 0, "Input dimension must be divisible by number of groups G."
        self.G = G  
        self.V = V  
        self.input_dim = input_dim
        self.group_dim = input_dim // G
        self.codebooks = nn.Parameter(torch.randn(G, V, input_dim // G))
        self.proj = nn.Linear(input_dim, G * V)
        self.register_buffer("_temperature", torch.tensor(float(temp_init)), persistent=True)
        self.temp_min = float(temp_min)
        self.temp_decay = float(temp_decay)

    @property
    def temperature(self):
        return float(self._temperature.item())

    def update_temperature(self):
        new_temp = max(self.temp_min, self.temperature * self.temp_decay)
        self._temperature.fill_(new_temp)

    def forward(self, x):
        if x.dim() != 3:
            raise ValueError("Expected shape: [B, T, C]")
        
        logits = self.proj(x)                              # [B, T, G * V]
        B, T, _ = logits.shape
        logits = logits.view(B, T, self.G, self.V)         # [B, T, G, V]
        probs = F.gumbel_softmax(logits=logits, tau=self.temperature, dim=-1, hard=True)
        cb = self.codebooks.unsqueeze(0).unsqueeze(0)    # [1, 1, G, V, C/G]
        probs_expanded = probs.unsqueeze(-1)                  # [B, T, G, V, 1]
        quantized = torch.sum(probs_expanded * cb, dim=-2)   # [B, T, G, C/G]
        quantized = quantized.view(B, T, self.input_dim)      # [B, T, C]

        soft = torch.softmax(logits, dim=-1)
        p_bar = soft.mean(dim=(0,1))  
        eps = 1e-9
        diversity_loss = (p_bar * torch.log(p_bar + eps)).sum() / (self.G * self.V)
        self.update_temperature()
        return quantized, diversity_loss

In [7]:
class TransformerLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, ffn_dim, dropout=.1):
        super().__init__()
        self.self_attention = nn.MultiheadAttention(embed_dim=embed_dim,
                                                    num_heads=num_heads,
                                                    dropout=dropout,
                                                    batch_first=True)
        
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(nn.Linear(embed_dim, ffn_dim),
                                 nn.GELU(),
                                 nn.Dropout(dropout),
                                 nn.Linear(ffn_dim, embed_dim),
                                 nn.Dropout(dropout))
        
    def forward(self, x, padding_mask=None):
        residual = x
        x = self.norm1(x)
        x, _ = self.self_attention(x, x, x, key_padding_mask=padding_mask)
        x = x + residual

        residual = x
        x = self.norm2(x)
        x = self.ffn(x)
        x = residual + x
        return x

In [8]:
class PositionalEmbedding(nn.Module):
    def __init__(self, embed_dim, max_len=5000):
        super().__init__()
        self.pos_emb = nn.Parameter(torch.randn(1, max_len, embed_dim) * 0.02)
    
    def forward(self, x):
        B, T, C = x.shape
        return x + self.pos_emb[:, :T, :]

In [9]:
class TransformerEncoder(nn.Module):
    def __init__(self, embed_dim=512, num_layers=12, num_heads=8, 
                 ffn_dim=2048, dropout=0.1, max_len=5000):
        super().__init__()
        self.pos_embedding = PositionalEmbedding(embed_dim, max_len)
        self.layers = nn.ModuleList([
            TransformerLayer(embed_dim, num_heads, ffn_dim, dropout)
            for _ in range(num_layers)
        ])
        self.final_norm = nn.LayerNorm(embed_dim)
        
    def forward(self, x, padding_mask=None):
        x = self.pos_embedding(x)
        for layer in self.layers:
            x = layer(x, padding_mask)
        x = self.final_norm(x)
        
        return x

In [10]:
class Wav2Vec(nn.Module):
    def __init__(self,
                 in_channels=1,
                 conv_channels=[512, 512, 512, 512, 512, 512, 512],
                 kernel_sizes=[10, 3, 3, 3, 3, 2, 2],
                 strides=[5, 2, 2, 2, 2, 2, 2],
                 encoder_dropout=0.1,
                 mask_prob=0.065,
                 mask_length=10,
                 embed_dim=512,
                 num_transformer_layers=12,
                 num_heads=8,
                 ffn_dim=2048,
                 transformer_dropout=0.1,
                 num_groups=2,
                 num_vars=320,
                 temp_init=2.0,
                 temp_min=0.5,
                 temp_decay=0.999995,
                 sample_rate=16000,
                 seed=None):
        super().__init__()
        
        if seed is not None:
            torch.manual_seed(seed)
        
        self.encoder = Encoder(
            seed=seed,
            in_channels=in_channels,
            conv_channels=conv_channels,
            kernel_sizes=kernel_sizes,
            strides=strides,
            dropout=encoder_dropout,
            sample_rate=sample_rate
        )
        
        self.masker = Masker(
            embed_dim=embed_dim,
            mask_prob=mask_prob,
            mask_length=mask_length
        )
        
        self.transformer = TransformerEncoder(
            embed_dim=embed_dim,
            num_layers=num_transformer_layers,
            num_heads=num_heads,
            ffn_dim=ffn_dim,
            dropout=transformer_dropout,
            max_len=5000
        )
        
        self.quantizer = Quantizer(
            input_dim=embed_dim,
            G=num_groups,
            V=num_vars,
            temp_init=temp_init,
            temp_min=temp_min,
            temp_decay=temp_decay
        )
        
        self.projection = nn.Linear(embed_dim, embed_dim)
        
    def forward(self, audio, input_lengths):
        features, output_lengths, padding_mask = self.encoder(audio, input_lengths)
        
        masked_features, mask_indices_b, mask_indices_t, mask = self.masker(features)
        
        contextualized = self.transformer(masked_features, padding_mask)
        
        quantized, diversity_loss = self.quantizer(features)
        
        return {
            'contextualized': contextualized,
            'quantized': quantized,
            'mask_indices_b': mask_indices_b,
            'mask_indices_t': mask_indices_t,
            'diversity_loss': diversity_loss,
            'features': features,
            'padding_mask': padding_mask,
            'output_lengths': output_lengths
        }
    
    def compute_contrastive_loss(self, outputs, num_negatives=100, temperature=0.1):
        contextualized = outputs['contextualized']
        quantized = outputs['quantized']
        mask_indices_b = outputs['mask_indices_b']
        mask_indices_t = outputs['mask_indices_t']
        
        if len(mask_indices_b) == 0:
            return torch.tensor(0.0, device=contextualized.device)
        
        c_masked = contextualized[mask_indices_b, mask_indices_t]  
        q_masked = quantized[mask_indices_b, mask_indices_t]       
        c_masked = self.projection(c_masked)  
        
        pos_logits = F.cosine_similarity(c_masked, q_masked, dim=-1)  
        pos_logits = pos_logits / temperature
        
        B, T, C = quantized.shape
        M = c_masked.shape[0]
        
        neg_indices = torch.randint(0, B * T, (M, num_negatives), device=quantized.device)
        quantized_flat = quantized.view(-1, C)
        negatives = quantized_flat[neg_indices] 
        
        c_expanded = c_masked.unsqueeze(1)  
        neg_logits = F.cosine_similarity(c_expanded, negatives, dim=-1)  
        neg_logits = neg_logits / temperature
        
        logits = torch.cat([pos_logits.unsqueeze(1), neg_logits], dim=1)  
        labels = torch.zeros(M, dtype=torch.long, device=logits.device)  
        
        contrastive_loss = F.cross_entropy(logits, labels)
        
        return contrastive_loss
    
    def compute_loss(self, outputs, num_negatives=100, temperature=0.1, diversity_weight=0.1):
        contrastive_loss = self.compute_contrastive_loss(outputs, num_negatives, temperature)
        diversity_loss = outputs['diversity_loss']
        
        total_loss = contrastive_loss + diversity_weight * diversity_loss
        
        loss_dict = {
            'total_loss': total_loss.mean().item() if total_loss.numel() > 1 else total_loss.item(),
            'contrastive_loss': contrastive_loss.mean().item() if contrastive_loss.numel() > 1 else contrastive_loss.item(),
            'diversity_loss': diversity_loss.mean().item() if diversity_loss.numel() > 1 else diversity_loss.item()
        }
        
        return total_loss, loss_dict

In [11]:
def train_epoch(model, dataloader, optimizer, device, epoch, num_negatives=100, temperature=0.1, diversity_weight=0.1):
    model.train()
    
    total_loss = 0.0
    total_contrastive = 0.0
    total_diversity = 0.0
    num_batches = 0
    
    start_time = time.time()
    
    for batch_idx, batch in enumerate(dataloader):
        audio = batch['audio'].to(device)
        lengths = batch['lengths'].to(device)
        
        outputs = model(audio, lengths)

        model_ref = model.module if isinstance(model, nn.DataParallel) else model
        loss, loss_dict = model_ref.compute_loss(
            outputs,
            num_negatives=num_negatives,
            temperature=temperature,
            diversity_weight=diversity_weight
        )

        optimizer.zero_grad()
        if loss.numel() > 1:
            loss = loss.mean()
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        
        total_loss += loss_dict['total_loss']
        total_contrastive += loss_dict['contrastive_loss']
        total_diversity += loss_dict['diversity_loss']
        num_batches += 1
        
        if (batch_idx + 1) % 250 == 0:
            elapsed = time.time() - start_time
            print(f"Epoch {epoch} | Batch {batch_idx + 1}/{len(dataloader)} | "
                  f"Loss: {loss_dict['total_loss']:.4f} | "
                  f"Contrastive: {loss_dict['contrastive_loss']:.4f} | "
                  f"Diversity: {loss_dict['diversity_loss']:.4f} | "
                  f"Time: {elapsed:.2f}s")
            start_time = time.time()
    
    avg_loss = total_loss / num_batches
    avg_contrastive = total_contrastive / num_batches
    avg_diversity = total_diversity / num_batches
    
    return {
        'loss': avg_loss,
        'contrastive_loss': avg_contrastive,
        'diversity_loss': avg_diversity
    }


def save_checkpoint(model, optimizer, epoch, loss, save_dir, filename=None):
    os.makedirs(save_dir, exist_ok=True)
    
    if filename is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"wav2vec_epoch{epoch}_{timestamp}.pt"
    
    filepath = os.path.join(save_dir, filename)
    
    model_state = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
    
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model_state,
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    }
    torch.save(checkpoint, filepath)    
    return filepath

In [12]:
DATA_DIR = '/kaggle/input/libri100h/dev-clean'
SAVE_DIR = '/kaggle/working/models'
NUM_EPOCHS = 20
BATCH_SIZE = 2
LEARNING_RATE = 5e-4
NUM_NEGATIVES = 100
TEMPERATURE = 0.1
DIVERSITY_WEIGHT = 0.1
SAVE_EVERY = 5
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

print("-"*60)
print(f"Device: {DEVICE}")
print(f"Data directory: {DATA_DIR}")
print(f"Save directory: {SAVE_DIR}")
print(f"Epochs: {NUM_EPOCHS}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Learning rate: {LEARNING_RATE}")
print("-"*60)

------------------------------------------------------------
Device: cuda
Data directory: /kaggle/input/libri100h/dev-clean
Save directory: /kaggle/working/models
Epochs: 20
Batch size: 2
Learning rate: 0.0005
------------------------------------------------------------


In [13]:
print("\nLoading dataset")
dataset = LibriSpeechDataset(DATA_DIR)
print(f"Dataset size: {len(dataset)} samples")

num_workers = 4 if DEVICE == 'cuda' else 0
pin_memory = DEVICE == 'cuda'

dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=pin_memory,
    persistent_workers=True if num_workers > 0 else False
)


Loading dataset
Dataset size: 2703 samples


In [14]:
print("\nInitializing model")
model = Wav2Vec(
    in_channels=1,
    conv_channels=[512, 512, 512, 512, 512, 512, 512],
    kernel_sizes=[10, 3, 3, 3, 3, 2, 2],
    strides=[5, 2, 2, 2, 2, 2, 2],
    encoder_dropout=0.1,
    mask_prob=0.065,
    mask_length=10,
    embed_dim=512,
    num_transformer_layers=12,
    num_heads=8,
    ffn_dim=2048,
    transformer_dropout=0.1,
    num_groups=2,
    num_vars=320,
    temp_init=2.0,
    temp_min=0.5,
    temp_decay=0.999995,
    sample_rate=16000,
    seed=42
).to(DEVICE)

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model parameters: {num_params:,}")

if DEVICE == 'cuda' and torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs with DataParallel")
    model = nn.DataParallel(model)
elif DEVICE == 'cuda':
    print(f"Using single GPU")

optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.01)


Initializing model
Model parameters: 45,351,552
Using 2 GPUs with DataParallel


In [15]:
best_loss = float('inf')
for epoch in range(1, NUM_EPOCHS + 1):
    print(f"EPOCH {epoch}/{NUM_EPOCHS}")
    
    metrics = train_epoch(
        model, dataloader, optimizer, DEVICE, epoch,
        num_negatives=NUM_NEGATIVES,
        temperature=TEMPERATURE,
        diversity_weight=DIVERSITY_WEIGHT
    )
    
    print(f"Epoch {epoch} Summary:")
    print(f"  Average Loss: {metrics['loss']:.4f}")
    print(f"  Contrastive Loss: {metrics['contrastive_loss']:.4f}")
    print(f"  Diversity Loss: {metrics['diversity_loss']:.4f}")
    
    if epoch % SAVE_EVERY == 0:
        save_checkpoint(model, optimizer, epoch, metrics['loss'], SAVE_DIR)
    
    if metrics['loss'] < best_loss:
        best_loss = metrics['loss']
        save_checkpoint(model, optimizer, epoch, metrics['loss'], SAVE_DIR, 
                      filename='best_model.pt')

save_checkpoint(model, optimizer, NUM_EPOCHS, metrics['loss'], SAVE_DIR, 
               filename='final_model.pt')

EPOCH 1/20




Epoch 1 | Batch 250/1352 | Loss: 4.5250 | Contrastive: 4.5266 | Diversity: -0.0166 | Time: 50.63s
Epoch 1 | Batch 500/1352 | Loss: 4.3062 | Contrastive: 4.3068 | Diversity: -0.0056 | Time: 47.04s
Epoch 1 | Batch 750/1352 | Loss: 4.6041 | Contrastive: 4.6044 | Diversity: -0.0026 | Time: 47.25s
Epoch 1 | Batch 1000/1352 | Loss: 4.6141 | Contrastive: 4.6145 | Diversity: -0.0036 | Time: 50.25s
Epoch 1 | Batch 1250/1352 | Loss: 4.6236 | Contrastive: 4.6240 | Diversity: -0.0041 | Time: 50.47s
Epoch 1 Summary:
  Average Loss: 4.5892
  Contrastive Loss: 4.5900
  Diversity Loss: -0.0074
EPOCH 2/20
Epoch 2 | Batch 250/1352 | Loss: 4.6232 | Contrastive: 4.6237 | Diversity: -0.0051 | Time: 49.99s
Epoch 2 | Batch 500/1352 | Loss: 4.6230 | Contrastive: 4.6235 | Diversity: -0.0055 | Time: 50.56s
Epoch 2 | Batch 750/1352 | Loss: 4.6163 | Contrastive: 4.6168 | Diversity: -0.0053 | Time: 50.83s
Epoch 2 | Batch 1000/1352 | Loss: 4.6133 | Contrastive: 4.6139 | Diversity: -0.0063 | Time: 47.87s
Epoch 2 | B

'/kaggle/working/models/final_model.pt'