In [None]:
# Mount Google Drive so the dataset files stored there are reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')
# NOTE(review): `datadir` is reassigned later in the run cell (to ".../FP2/data/", the parent of
# `snippets/`), so this value is overwritten before training uses it.
# NOTE(review): this path spells "MyDrive" while `save_dir` and the run cell use "My Drive";
# presumably both resolve on Colab — confirm and standardise on one spelling.
datadir = "/content/drive/MyDrive/CS441/FP2/data/snippets"
# NOTE(review): `save_dir` is not referenced by any visible code; the training cell writes its
# checkpoint to "best_multimodal_model.pt" in the current working directory instead.
save_dir = "/content/drive/My Drive/CS441/FP2"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!pip install torchcodec



In [None]:
import os
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchaudio
from transformers import (
    DistilBertTokenizerFast,
    DistilBertModel,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2Model
)
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm

In [None]:
class PodcastDataset(Dataset):
    """Podcast audio snippets paired with transcripts and integer genre labels.

    Each CSV row must provide ``genre``, ``path`` and ``transcript`` columns.
    The audio file for a row is read from ``<audio_root>/snippets/<genre>/<path>``.

    Args:
        csv_path: CSV describing this split.
        audio_root: Directory that contains the ``snippets/`` tree.
        tokenizer: Hugging Face tokenizer applied to non-empty transcripts.
        feature_extractor: Hugging Face audio feature extractor applied to the
            16 kHz mono waveform.
        max_audio_len: Clip length in seconds; longer audio is cropped, shorter
            audio is zero-padded at the end.
        label_map: Optional ``{genre: index}`` mapping. Pass the training split's
            map when building val/test sets so label indices agree with the
            trained classifier even if a split lacks some genre. Defaults to the
            sorted genres found in ``csv_path``.
        max_text_len: Token length transcripts are truncated/padded to.

    Raises:
        ValueError: if ``label_map`` is given but does not cover every genre in
            the CSV.
    """

    TARGET_SR = 16000

    def __init__(self, csv_path, audio_root, tokenizer, feature_extractor, max_audio_len=5.0,
                 label_map=None, max_text_len=256):
        self.data = pd.read_csv(csv_path)
        self.audio_root = audio_root
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor
        self.max_audio_len = max_audio_len
        self.max_text_len = max_text_len
        # One Resample transform per source sample rate, built lazily and reused across items.
        self._resamplers = {}

        genres = sorted(self.data["genre"].unique())
        if label_map is None:
            self.label_map = {g: i for i, g in enumerate(genres)}
        else:
            unknown = sorted(set(genres) - set(label_map))
            if unknown:
                raise ValueError(f"Genres {unknown} in {csv_path} are missing from label_map")
            self.label_map = dict(label_map)
        print(f"üìä Loaded {len(self.data)} samples, {len(genres)} classes: {genres}")

    def __len__(self):
        return len(self.data)

    def _load_audio(self, row):
        """Return the feature extractor's ``input_values`` for the row's clip (batch dim removed)."""
        audio_path = os.path.join(self.audio_root, "snippets", row["genre"], row["path"])
        if not os.path.exists(audio_path):
            raise FileNotFoundError(f"Audio not found: {audio_path}")

        wav, sr = torchaudio.load(audio_path)

        # Down-mix multi-channel audio to mono: (channels, samples) -> (1, samples).
        if wav.size(0) > 1:
            wav = wav.mean(dim=0, keepdim=True)

        if sr != self.TARGET_SR:
            if sr not in self._resamplers:
                self._resamplers[sr] = torchaudio.transforms.Resample(
                    orig_freq=sr, new_freq=self.TARGET_SR
                )
            wav = self._resamplers[sr](wav)
            sr = self.TARGET_SR

        # Crop, or right-pad with zeros, to exactly max_audio_len seconds.
        target_len = int(self.max_audio_len * sr)
        if wav.size(1) > target_len:
            wav = wav[:, :target_len]
        elif wav.size(1) < target_len:
            wav = torch.nn.functional.pad(wav, (0, target_len - wav.size(1)))

        # Peak-normalise; the epsilon keeps silent clips from dividing by zero.
        wav = wav / (wav.abs().max() + 1e-8)

        features = self.feature_extractor(
            wav.squeeze(0).numpy(), sampling_rate=sr, return_tensors="pt"
        )
        return features["input_values"].squeeze(0)

    def _encode_text(self, text):
        """Tokenise a transcript and return ``(input_ids, attention_mask, use_text)``.

        Missing or blank transcripts yield all-zero tensors and ``use_text=False``
        so the model can skip its text branch for that sample.
        """
        max_len = self.max_text_len
        if pd.isna(text) or str(text).strip() == "":
            zeros = torch.zeros(max_len, dtype=torch.long)
            return zeros, zeros.clone(), False

        encoded = self.tokenizer(
            str(text),
            truncation=True,
            padding="max_length",
            max_length=max_len,
            return_tensors="pt",
        )
        return encoded["input_ids"].squeeze(0), encoded["attention_mask"].squeeze(0), True

    def __getitem__(self, idx):
        row = self.data.iloc[idx]

        audio_values = self._load_audio(row)
        input_ids, attention_mask, use_text = self._encode_text(row["transcript"])

        return {
            "audio_values": audio_values,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": torch.tensor(self.label_map[row["genre"]]),
            "use_text": torch.tensor(use_text, dtype=torch.bool),
        }

In [None]:
class ImprovedAudioTextClassifier(nn.Module):
    """Genre classifier that fuses Wav2Vec2 audio and DistilBERT text embeddings.

    Each encoder's output is pooled over time/tokens and projected to 512-d. The
    two projections are combined with a learned per-sample 2-way softmax
    weighting, and an MLP head then produces ``num_classes`` logits. Samples
    flagged ``use_text=False`` contribute a zero text vector.

    Args:
        num_classes: Number of output logits.
        freeze_pretrained: If True, freeze both encoders entirely. Otherwise only
            the first 4 Wav2Vec2 encoder layers and the first 2 DistilBERT layers
            are frozen.
    """

    def __init__(self, num_classes, freeze_pretrained=False):
        super().__init__()

        # Audio encoder
        self.audio_encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
        # Text encoder
        self.text_encoder = DistilBertModel.from_pretrained("distilbert-base-uncased")

        if freeze_pretrained:
            for param in self.audio_encoder.parameters():
                param.requires_grad = False
            for param in self.text_encoder.parameters():
                param.requires_grad = False
        else:
            # Partial fine-tuning: keep the lowest transformer layers fixed.
            for layer in self.audio_encoder.encoder.layers[:4]:
                for param in layer.parameters():
                    param.requires_grad = False
            for layer in self.text_encoder.transformer.layer[:2]:
                for param in layer.parameters():
                    param.requires_grad = False

        # The module layout below must stay unchanged so saved state_dicts keep loading.
        self.audio_proj = nn.Sequential(
            nn.Linear(768, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5)
        )

        self.text_proj = nn.Sequential(
            nn.Linear(768, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5)
        )

        # Produces two softmax weights (audio, text) per sample from the concatenated features.
        self.fusion_attention = nn.Sequential(
            nn.Linear(1024, 512),
            nn.Tanh(),
            nn.Linear(512, 2),
            nn.Softmax(dim=1)
        )

        self.classifier = nn.Sequential(
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    @staticmethod
    def _masked_mean(hidden_states, mask):
        """Average ``hidden_states`` (batch, seq, dim) over seq, counting only positions where
        ``mask`` (batch, seq) is nonzero. Rows whose mask is all zero yield zero vectors."""
        weights = mask.unsqueeze(-1).to(hidden_states.dtype)
        summed = (hidden_states * weights).sum(dim=1)
        counts = weights.sum(dim=1).clamp(min=1.0)
        return summed / counts

    def _project_text(self, pooled):
        """Apply ``text_proj``, tolerating a single-sample batch during training.

        BatchNorm1d raises in train mode when given one sample, which happens whenever only
        one item in a batch has a transcript. In that case the running statistics are used.
        """
        if self.training and pooled.size(0) == 1:
            self.text_proj.eval()
            try:
                return self.text_proj(pooled)
            finally:
                self.text_proj.train()
        return self.text_proj(pooled)

    def forward(self, input_audio, input_ids=None, attention_mask=None, use_text=None):
        """Return class logits of shape (batch, num_classes).

        Args:
            input_audio: Batched ``audio_values`` (as built by PodcastDataset), fed to Wav2Vec2.
            input_ids, attention_mask: Token tensors for the text branch.
            use_text: Bool tensor (batch,) marking samples that have a transcript; when None
                the text branch is skipped entirely.
        """
        # ===== audio features =====
        audio_out = self.audio_encoder(input_audio).last_hidden_state
        audio_feat = self.audio_proj(audio_out.mean(dim=1))  # [batch_size, 512]

        # ===== text features (zeros for samples without a transcript) =====
        batch_size = audio_feat.size(0)
        text_feat = torch.zeros(batch_size, 512, device=audio_feat.device, dtype=audio_feat.dtype)

        if use_text is not None:
            text_mask = use_text.bool()
            if text_mask.any():
                valid_attention_mask = attention_mask[text_mask]
                text_out = self.text_encoder(
                    input_ids=input_ids[text_mask],
                    attention_mask=valid_attention_mask,
                ).last_hidden_state
                # Pool over real tokens only; a plain mean would also average the padding
                # positions produced by padding="max_length".
                pooled = self._masked_mean(text_out, valid_attention_mask)
                text_feat[text_mask] = self._project_text(pooled)

        # ===== Attention Fusion =====
        fused = torch.cat([audio_feat, text_feat], dim=1)  # [batch_size, 1024]
        attn_weights = self.fusion_attention(fused)  # [batch_size, 2]
        final_feat = audio_feat * attn_weights[:, 0:1] + text_feat * attn_weights[:, 1:2]

        return self.classifier(final_feat)

In [None]:
# Train / Eval
# =====================================
def evaluate(model, loader, device):
    """Run ``model`` over ``loader`` without gradients and score its predictions.

    Each batch must be a dict with ``audio_values``, ``input_ids``,
    ``attention_mask``, ``use_text`` and ``labels`` tensors (as produced by
    PodcastDataset). The model is switched to eval mode and left there.

    Returns:
        ``(accuracy, all_preds, all_labels)``: the fraction of correct argmax
        predictions (0.0 for an empty loader), plus per-sample predicted and
        true class indices in loader order.
    """
    model.eval()
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in loader:
            logits = model(
                input_audio=batch["audio_values"].to(device),
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                use_text=batch["use_text"].to(device),
            )
            preds = logits.argmax(dim=1)
            labels = batch["labels"].to(device)

            correct += (preds == labels).sum().item()
            total += labels.size(0)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Guard against ZeroDivisionError when the loader yields no batches.
    accuracy = correct / total if total else 0.0
    return accuracy, all_preds, all_labels

In [None]:
def train_model(train_csv, val_csv, audio_root, batch_size=8, lr=2e-5, epochs=10, warmup_epochs=2):
    tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")

    train_ds = PodcastDataset(train_csv, audio_root, tokenizer, feature_extractor)
    val_ds = PodcastDataset(val_csv, audio_root, tokenizer, feature_extractor)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2)

    num_classes = len(train_ds.label_map)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ImprovedAudioTextClassifier
    model = ImprovedAudioTextClassifier(num_classes=num_classes, freeze_pretrained=False).to(device)

    # different learning rate
    pretrained_params = []
    new_params = []
    for name, param in model.named_parameters():
        if 'audio_encoder' in name or 'text_encoder' in name:
            pretrained_params.append(param)
        else:
            new_params.append(param)

    optimizer = AdamW([
        {'params': pretrained_params, 'lr': lr * 0.1},  # Pre-training layer with smaller learning rate
        {'params': new_params, 'lr': lr}
    ])

    # Add a learning rate scheduler
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs)

    criterion = nn.CrossEntropyLoss()
    best_acc = 0
    patience = 5
    no_improve = 0

    print(f"\nüöÄ Starting training...")
    print(f"Device: {device}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}\n")

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

        for batch in pbar:
            optimizer.zero_grad()

            logits = model(
                input_audio=batch["audio_values"].to(device),
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                use_text=batch["use_text"].to(device)
            )
            loss = criterion(logits, batch["labels"].to(device))

            if torch.isnan(loss) or torch.isinf(loss):
                print(f"\n‚ö†Ô∏è Warning: Invalid loss detected, skipping batch")
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

        avg_loss = total_loss / len(train_loader)
        val_acc, _, _ = evaluate(model, val_loader, device)

        print(f">>> Epoch {epoch+1} | Train Loss: {avg_loss:.4f} | Val Acc: {val_acc:.4f} | LR: {scheduler.get_last_lr()[0]:.2e}")

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save({
                'model_state_dict': model.state_dict(),
                'epoch': epoch,
                'val_acc': val_acc,
                'label_map': train_ds.label_map
            }, "best_multimodal_model.pt")
            print(f"‚úî Saved best model (Acc: {val_acc:.4f})")
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= patience:
                print(f"\n‚ö†Ô∏è Early stopping after {epoch+1} epochs (no improvement for {patience} epochs)")
                break

        scheduler.step()

    print(f"\nüéâ Training complete. Best Val Acc = {best_acc:.4f}")
    return model, train_ds.label_map

In [None]:
# Test
# =====================================
def test_model(model_path, test_csv, audio_root, label_map, batch_size=8):
    """Evaluate a saved checkpoint on the test split and print a metrics report.

    Args:
        model_path: Checkpoint written by train_model (dict with 'model_state_dict',
            'epoch' and 'val_acc').
        test_csv: Test split CSV read by PodcastDataset.
        audio_root: Directory containing the ``snippets/`` audio tree.
        label_map: ``{genre: index}`` mapping used during training.
        batch_size: Evaluation batch size.

    Returns:
        ``(test_acc, all_preds, all_labels)`` as returned by evaluate().

    Raises:
        ValueError: if the test CSV contains a genre missing from ``label_map``.
    """
    from sklearn.metrics import classification_report, confusion_matrix

    print("\n" + "="*50)
    print("üìä Testing on Test Set")
    print("="*50 + "\n")

    tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")

    test_ds = PodcastDataset(test_csv, audio_root, tokenizer, feature_extractor)
    # The dataset builds its own label_map from the test CSV. Replace it with the training
    # mapping so label indices agree with the model's output columns.
    unseen = sorted(set(test_ds.data["genre"].unique()) - set(label_map))
    if unseen:
        raise ValueError(f"Test genres not present in label_map: {unseen}")
    test_ds.label_map = dict(label_map)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=2)

    num_classes = len(label_map)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = ImprovedAudioTextClassifier(num_classes=num_classes, freeze_pretrained=False).to(device)

    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"‚úî Loaded best model from epoch {checkpoint['epoch']+1} (Val Acc: {checkpoint['val_acc']:.4f})\n")

    test_acc, all_preds, all_labels = evaluate(model, test_loader, device)

    idx_to_label = {v: k for k, v in label_map.items()}
    class_names = [idx_to_label[i] for i in range(num_classes)]
    # Pin the label set so the report and matrix stay num_classes wide even when a class
    # never appears in either labels or predictions.
    class_ids = list(range(num_classes))

    print(f"{'='*50}")
    print(f"üéØ Test Accuracy: {test_acc:.4f} ({test_acc*100:.2f}%)")
    print(f"{'='*50}\n")

    print("üìã Classification Report:")
    print("-" * 50)
    print(classification_report(all_labels, all_preds, labels=class_ids,
                                target_names=class_names, digits=4, zero_division=0))

    print("\nüìä Confusion Matrix:")
    print("-" * 50)
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)

    header = "True\\Pred  " + "  ".join([f"{name[:8]:>8}" for name in class_names])
    print(header)
    print("-" * len(header))
    for i, row in enumerate(cm):
        row_str = f"{class_names[i][:10]:<10} " + "  ".join([f"{val:>8}" for val in row])
        print(row_str)

    print("\nüìà Per-Class Accuracy:")
    print("-" * 50)
    class_correct = cm.diagonal()
    class_total = cm.sum(axis=1)
    for i, name in enumerate(class_names):
        acc = class_correct[i] / class_total[i] if class_total[i] > 0 else 0
        print(f"{name:20s}: {acc:.4f} ({class_correct[i]:3d}/{class_total[i]:3d})")

    print("\nüîÄ Most Confused Class Pairs:")
    print("-" * 50)
    confusion_pairs = [
        (cm[i, j], class_names[i], class_names[j])
        for i in range(num_classes)
        for j in range(num_classes)
        if i != j and cm[i, j] > 0
    ]
    confusion_pairs.sort(reverse=True)
    for count, true_class, pred_class in confusion_pairs[:10]:
        print(f"{true_class:15s} ‚Üí {pred_class:15s}: {count:3d} times")

    print("\n" + "="*50)
    print("‚úÖ Testing Complete!")
    print("="*50)

    return test_acc, all_preds, all_labels

In [None]:
# Run training and testing
if __name__ == "__main__":
    # Folder holding the *_transcripts.csv split files and the `snippets/` audio tree.
    datadir = "/content/drive/My Drive/CS441/FP2/data/"

    # Fine-tune the fused audio+text classifier; the best epoch is checkpointed to
    # best_multimodal_model.pt in the working directory.
    model, label_map = train_model(
        os.path.join(datadir, "train_transcripts.csv"),
        os.path.join(datadir, "val_transcripts.csv"),
        datadir,
        batch_size=8,
        lr=2e-5,
        epochs=20,
        warmup_epochs=2,
    )

    # Reload that checkpoint and report metrics on the held-out test split.
    test_acc, test_preds, test_labels = test_model(
        "best_multimodal_model.pt",
        os.path.join(datadir, "test_transcripts.csv"),
        datadir,
        label_map,
        batch_size=8,
    )

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


üìä Loaded 1080 samples, 6 classes: ['business', 'comedy', 'education', 'news', 'religion', 'sports']
üìä Loaded 240 samples, 6 classes: ['business', 'comedy', 'education', 'news', 'religion', 'sports']


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



üöÄ Starting training...
Device: cuda
Total parameters: 162,446,472
Trainable parameters: 119,919,240



Epoch 1/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 135/135 [00:33<00:00,  4.00it/s, loss=1.6857]


>>> Epoch 1 | Train Loss: 1.8040 | Val Acc: 0.3417 | LR: 2.00e-06
‚úî Saved best model (Acc: 0.3417)


Epoch 2/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 135/135 [00:32<00:00,  4.18it/s, loss=1.7868]


>>> Epoch 2 | Train Loss: 1.6961 | Val Acc: 0.4958 | LR: 1.99e-06
‚úî Saved best model (Acc: 0.4958)


Epoch 3/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 135/135 [00:33<00:00,  4.08it/s, loss=1.5980]


>>> Epoch 3 | Train Loss: 1.6163 | Val Acc: 0.5500 | LR: 1.95e-06
‚úî Saved best model (Acc: 0.5500)


Epoch 4/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 135/135 [00:31<00:00,  4.27it/s, loss=1.4534]


>>> Epoch 4 | Train Loss: 1.4760 | Val Acc: 0.5708 | LR: 1.89e-06
‚úî Saved best model (Acc: 0.5708)


Epoch 5/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 135/135 [00:33<00:00,  4.07it/s, loss=1.3061]


>>> Epoch 5 | Train Loss: 1.3538 | Val Acc: 0.5792 | LR: 1.81e-06
‚úî Saved best model (Acc: 0.5792)


Epoch 6/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 135/135 [00:33<00:00,  4.07it/s, loss=1.5471]


>>> Epoch 6 | Train Loss: 1.2811 | Val Acc: 0.5917 | LR: 1.71e-06
‚úî Saved best model (Acc: 0.5917)


Epoch 7/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 135/135 [00:33<00:00,  4.07it/s, loss=1.1430]


>>> Epoch 7 | Train Loss: 1.1831 | Val Acc: 0.5958 | LR: 1.59e-06
‚úî Saved best model (Acc: 0.5958)


Epoch 8/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 135/135 [00:32<00:00,  4.10it/s, loss=1.0811]


>>> Epoch 8 | Train Loss: 1.1534 | Val Acc: 0.5833 | LR: 1.45e-06


Epoch 9/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 135/135 [00:31<00:00,  4.31it/s, loss=0.7985]


>>> Epoch 9 | Train Loss: 1.0732 | Val Acc: 0.5792 | LR: 1.31e-06


Epoch 10/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 135/135 [00:31<00:00,  4.28it/s, loss=1.1725]


>>> Epoch 10 | Train Loss: 1.0263 | Val Acc: 0.6083 | LR: 1.16e-06
‚úî Saved best model (Acc: 0.6083)


Epoch 11/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 135/135 [00:33<00:00,  4.06it/s, loss=0.7862]


>>> Epoch 11 | Train Loss: 1.0094 | Val Acc: 0.5875 | LR: 1.00e-06


Epoch 12/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 135/135 [00:31<00:00,  4.30it/s, loss=0.7405]


>>> Epoch 12 | Train Loss: 0.9625 | Val Acc: 0.6000 | LR: 8.44e-07


Epoch 13/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 135/135 [00:31<00:00,  4.29it/s, loss=0.8308]


>>> Epoch 13 | Train Loss: 0.9410 | Val Acc: 0.5958 | LR: 6.91e-07


Epoch 14/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 135/135 [00:31<00:00,  4.24it/s, loss=1.4744]


>>> Epoch 14 | Train Loss: 0.9180 | Val Acc: 0.6125 | LR: 5.46e-07
‚úî Saved best model (Acc: 0.6125)


Epoch 15/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 135/135 [00:31<00:00,  4.30it/s, loss=1.1453]


>>> Epoch 15 | Train Loss: 0.8972 | Val Acc: 0.6125 | LR: 4.12e-07


Epoch 16/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 135/135 [00:31<00:00,  4.28it/s, loss=0.9010]


>>> Epoch 16 | Train Loss: 0.9094 | Val Acc: 0.6208 | LR: 2.93e-07
‚úî Saved best model (Acc: 0.6208)


Epoch 17/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 135/135 [00:32<00:00,  4.18it/s, loss=0.8812]


>>> Epoch 17 | Train Loss: 0.8770 | Val Acc: 0.5917 | LR: 1.91e-07


Epoch 18/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 135/135 [00:31<00:00,  4.24it/s, loss=0.8030]


>>> Epoch 18 | Train Loss: 0.8858 | Val Acc: 0.6083 | LR: 1.09e-07


Epoch 19/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 135/135 [00:31<00:00,  4.23it/s, loss=0.3825]


>>> Epoch 19 | Train Loss: 0.8592 | Val Acc: 0.6083 | LR: 4.89e-08


Epoch 20/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 135/135 [00:31<00:00,  4.29it/s, loss=1.0091]


>>> Epoch 20 | Train Loss: 0.8591 | Val Acc: 0.6042 | LR: 1.23e-08

üéâ Training complete. Best Val Acc = 0.6208

üìä Testing on Test Set

üìä Loaded 240 samples, 6 classes: ['business', 'comedy', 'education', 'news', 'religion', 'sports']


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


‚úî Loaded best model from epoch 16 (Val Acc: 0.6208)

üéØ Test Accuracy: 0.7750 (77.50%)

üìã Classification Report:
--------------------------------------------------
              precision    recall  f1-score   support

    business     0.6415    0.8500    0.7312        40
      comedy     0.7297    0.6750    0.7013        40
   education     0.6667    0.8500    0.7473        40
        news     0.8571    0.4500    0.5902        40
    religion     0.9756    1.0000    0.9877        40
      sports     0.8919    0.8250    0.8571        40

    accuracy                         0.7750       240
   macro avg     0.7938    0.7750    0.7691       240
weighted avg     0.7938    0.7750    0.7691       240


üìä Confusion Matrix:
--------------------------------------------------
True\Pred  business    comedy  educatio      news  religion    sports
---------------------------------------------------------------------
business         34         1         4         1         0         0
c

Tested

In [None]:
# Test Accuracy: 0.7875 (78.75%)
import os
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchaudio
from transformers import (
    DistilBertTokenizerFast,
    DistilBertModel,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2Model
)
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm

# =====================================
# Dataset (improved version)
# =====================================
class PodcastDataset(Dataset):
    """Podcast audio snippets paired with transcripts and integer genre labels.

    Each CSV row must provide ``genre``, ``path`` and ``transcript`` columns.
    The audio file for a row is read from ``<audio_root>/snippets/<genre>/<path>``.

    Args:
        csv_path: CSV describing this split.
        audio_root: Directory that contains the ``snippets/`` tree.
        tokenizer: Hugging Face tokenizer applied to non-empty transcripts.
        feature_extractor: Hugging Face audio feature extractor applied to the
            16 kHz mono waveform.
        max_audio_len: Maximum audio length in seconds; longer audio is cropped,
            shorter audio is zero-padded at the end.
        label_map: Optional ``{genre: index}`` mapping. Pass the training split's
            map when building val/test sets so label indices agree with the
            trained classifier even if a split lacks some genre. Defaults to the
            sorted genres found in ``csv_path``.
        max_text_len: Token length transcripts are truncated/padded to.

    Raises:
        ValueError: if ``label_map`` is given but does not cover every genre in
            the CSV.
    """

    TARGET_SR = 16000

    def __init__(self, csv_path, audio_root, tokenizer, feature_extractor, max_audio_len=5.0,
                 label_map=None, max_text_len=256):
        self.data = pd.read_csv(csv_path)
        self.audio_root = audio_root
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor
        self.max_audio_len = max_audio_len  # maximum audio length (seconds)
        self.max_text_len = max_text_len
        # One Resample transform per source sample rate, built lazily and reused across items.
        self._resamplers = {}

        genres = sorted(self.data["genre"].unique())
        if label_map is None:
            self.label_map = {g: i for i, g in enumerate(genres)}
        else:
            unknown = sorted(set(genres) - set(label_map))
            if unknown:
                raise ValueError(f"Genres {unknown} in {csv_path} are missing from label_map")
            self.label_map = dict(label_map)
        print(f"üìä Loaded {len(self.data)} samples, {len(genres)} classes: {genres}")

    def __len__(self):
        return len(self.data)

    def _load_audio(self, row):
        """Return the feature extractor's ``input_values`` for the row's clip (batch dim removed)."""
        audio_path = os.path.join(self.audio_root, "snippets", row["genre"], row["path"])
        if not os.path.exists(audio_path):
            raise FileNotFoundError(f"Audio not found: {audio_path}")

        wav, sr = torchaudio.load(audio_path)

        # Convert to mono: (channels, samples) -> (1, samples).
        if wav.size(0) > 1:
            wav = wav.mean(dim=0, keepdim=True)

        # Resample to 16 kHz.
        if sr != self.TARGET_SR:
            if sr not in self._resamplers:
                self._resamplers[sr] = torchaudio.transforms.Resample(
                    orig_freq=sr, new_freq=self.TARGET_SR
                )
            wav = self._resamplers[sr](wav)
            sr = self.TARGET_SR

        # Crop, or right-pad with zeros, to a fixed length.
        target_len = int(self.max_audio_len * sr)
        if wav.size(1) > target_len:
            wav = wav[:, :target_len]
        elif wav.size(1) < target_len:
            wav = torch.nn.functional.pad(wav, (0, target_len - wav.size(1)))

        # Peak-normalise the audio; the epsilon keeps silent clips from dividing by zero.
        wav = wav / (wav.abs().max() + 1e-8)

        features = self.feature_extractor(
            wav.squeeze(0).numpy(), sampling_rate=sr, return_tensors="pt"
        )
        return features["input_values"].squeeze(0)

    def _encode_text(self, text):
        """Tokenise a transcript and return ``(input_ids, attention_mask, use_text)``.

        Missing or blank transcripts yield all-zero tensors and ``use_text=False``
        so the model can skip its text branch for that sample.
        """
        max_len = self.max_text_len
        if pd.isna(text) or str(text).strip() == "":
            zeros = torch.zeros(max_len, dtype=torch.long)
            return zeros, zeros.clone(), False

        encoded = self.tokenizer(
            str(text),
            truncation=True,
            padding="max_length",
            max_length=max_len,
            return_tensors="pt",
        )
        return encoded["input_ids"].squeeze(0), encoded["attention_mask"].squeeze(0), True

    def __getitem__(self, idx):
        row = self.data.iloc[idx]

        audio_values = self._load_audio(row)
        input_ids, attention_mask, use_text = self._encode_text(row["transcript"])

        return {
            "audio_values": audio_values,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": torch.tensor(self.label_map[row["genre"]]),
            "use_text": torch.tensor(use_text, dtype=torch.bool),
        }


# =====================================
# Improved model architecture
# =====================================
class ImprovedAudioTextClassifier(nn.Module):
    """Genre classifier that fuses Wav2Vec2 audio and DistilBERT text embeddings.

    Each encoder's output is pooled over time/tokens and projected to 512-d. The
    two projections are combined with a learned per-sample 2-way softmax
    weighting (instead of plain concatenation), and an MLP head then produces
    ``num_classes`` logits. Samples flagged ``use_text=False`` contribute a zero
    text vector.

    Args:
        num_classes: Number of output logits.
        freeze_pretrained: If True, freeze both encoders entirely. Otherwise only
            the first 4 Wav2Vec2 encoder layers and the first 2 DistilBERT layers
            are frozen.
    """

    def __init__(self, num_classes, freeze_pretrained=False):
        super().__init__()

        # Audio encoder
        self.audio_encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
        # Text encoder
        self.text_encoder = DistilBertModel.from_pretrained("distilbert-base-uncased")

        if freeze_pretrained:
            for param in self.audio_encoder.parameters():
                param.requires_grad = False
            for param in self.text_encoder.parameters():
                param.requires_grad = False
        else:
            # Partial fine-tuning: freeze the first 4 Wav2Vec2 layers and the first 2
            # DistilBERT layers; everything else stays trainable.
            for layer in self.audio_encoder.encoder.layers[:4]:
                for param in layer.parameters():
                    param.requires_grad = False
            for layer in self.text_encoder.transformer.layer[:2]:
                for param in layer.parameters():
                    param.requires_grad = False

        # Projection layers (with BatchNorm and Dropout). The module layout must stay
        # unchanged so saved state_dicts keep loading.
        self.audio_proj = nn.Sequential(
            nn.Linear(768, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5)
        )

        self.text_proj = nn.Sequential(
            nn.Linear(768, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5)
        )

        # Attention fusion layer: two softmax weights (audio, text) per sample.
        self.fusion_attention = nn.Sequential(
            nn.Linear(1024, 512),
            nn.Tanh(),
            nn.Linear(512, 2),
            nn.Softmax(dim=1)
        )

        # Classifier head
        self.classifier = nn.Sequential(
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    @staticmethod
    def _masked_mean(hidden_states, mask):
        """Average ``hidden_states`` (batch, seq, dim) over seq, counting only positions where
        ``mask`` (batch, seq) is nonzero. Rows whose mask is all zero yield zero vectors."""
        weights = mask.unsqueeze(-1).to(hidden_states.dtype)
        summed = (hidden_states * weights).sum(dim=1)
        counts = weights.sum(dim=1).clamp(min=1.0)
        return summed / counts

    def _project_text(self, pooled):
        """Apply ``text_proj``, tolerating a single-sample batch during training.

        BatchNorm1d raises in train mode when given one sample, which happens whenever only
        one item in a batch has a transcript. In that case the running statistics are used.
        """
        if self.training and pooled.size(0) == 1:
            self.text_proj.eval()
            try:
                return self.text_proj(pooled)
            finally:
                self.text_proj.train()
        return self.text_proj(pooled)

    def forward(self, input_audio, input_ids=None, attention_mask=None, use_text=None):
        """Return class logits of shape (batch, num_classes).

        Args:
            input_audio: Batched ``audio_values`` (as built by PodcastDataset), fed to Wav2Vec2.
            input_ids, attention_mask: Token tensors for the text branch.
            use_text: Bool tensor (batch,) marking samples that have a transcript; when None
                the text branch is skipped entirely.
        """
        # ===== Audio features =====
        audio_out = self.audio_encoder(input_audio).last_hidden_state
        audio_feat = self.audio_proj(audio_out.mean(dim=1))  # [batch_size, 512]

        # ===== Text features (zeros for samples without a transcript) =====
        batch_size = audio_feat.size(0)
        text_feat = torch.zeros(batch_size, 512, device=audio_feat.device, dtype=audio_feat.dtype)

        if use_text is not None:
            text_mask = use_text.bool()
            if text_mask.any():
                valid_attention_mask = attention_mask[text_mask]
                text_out = self.text_encoder(
                    input_ids=input_ids[text_mask],
                    attention_mask=valid_attention_mask,
                ).last_hidden_state
                # Pool over real tokens only; a plain mean would also average the padding
                # positions produced by padding="max_length".
                pooled = self._masked_mean(text_out, valid_attention_mask)
                text_feat[text_mask] = self._project_text(pooled)

        # ===== Attention fusion (replaces plain concatenation) =====
        fused = torch.cat([audio_feat, text_feat], dim=1)  # [batch_size, 1024]
        attn_weights = self.fusion_attention(fused)  # [batch_size, 2]
        final_feat = audio_feat * attn_weights[:, 0:1] + text_feat * attn_weights[:, 1:2]

        return self.classifier(final_feat)


# =====================================
# Train / Eval (improved version)
# =====================================
def evaluate(model, loader, device):
    """Run ``model`` over ``loader`` without gradients and score its predictions.

    Each batch must be a dict with ``audio_values``, ``input_ids``,
    ``attention_mask``, ``use_text`` and ``labels`` tensors (as produced by
    PodcastDataset). The model is switched to eval mode and left there.

    Returns:
        ``(accuracy, all_preds, all_labels)``: the fraction of correct argmax
        predictions (0.0 for an empty loader), plus per-sample predicted and
        true class indices in loader order.
    """
    model.eval()
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in loader:
            logits = model(
                input_audio=batch["audio_values"].to(device),
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                use_text=batch["use_text"].to(device),
            )
            preds = logits.argmax(dim=1)
            labels = batch["labels"].to(device)

            correct += (preds == labels).sum().item()
            total += labels.size(0)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Guard against ZeroDivisionError when the loader yields no batches.
    accuracy = correct / total if total else 0.0
    return accuracy, all_preds, all_labels


def train_model(train_csv, val_csv, audio_root, batch_size=8, lr=2e-5, epochs=10, warmup_epochs=2):
    tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")

    train_ds = PodcastDataset(train_csv, audio_root, tokenizer, feature_extractor)
    val_ds = PodcastDataset(val_csv, audio_root, tokenizer, feature_extractor)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2)

    num_classes = len(train_ds.label_map)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ‰ΩøÁî®ÊîπËøõÁöÑÊ®°Âûã
    model = ImprovedAudioTextClassifier(num_classes=num_classes, freeze_pretrained=False).to(device)

    # ÂàÜÁªÑÂ≠¶‰π†Áéá
    pretrained_params = []
    new_params = []
    for name, param in model.named_parameters():
        if 'audio_encoder' in name or 'text_encoder' in name:
            pretrained_params.append(param)
        else:
            new_params.append(param)

    optimizer = AdamW([
        {'params': pretrained_params, 'lr': lr * 0.1},  # È¢ÑËÆ≠ÁªÉÂ±ÇÁî®Êõ¥Â∞èÂ≠¶‰π†Áéá
        {'params': new_params, 'lr': lr}
    ])

    # Ê∑ªÂä†Â≠¶‰π†ÁéáË∞ÉÂ∫¶Âô®
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs)

    criterion = nn.CrossEntropyLoss()
    best_acc = 0
    patience = 5
    no_improve = 0

    print(f"\nüöÄ Starting training...")
    print(f"Device: {device}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}\n")

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

        for batch in pbar:
            optimizer.zero_grad()

            logits = model(
                input_audio=batch["audio_values"].to(device),
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                use_text=batch["use_text"].to(device)
            )
            loss = criterion(logits, batch["labels"].to(device))

            if torch.isnan(loss) or torch.isinf(loss):
                print(f"\n‚ö†Ô∏è Warning: Invalid loss detected, skipping batch")
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

        avg_loss = total_loss / len(train_loader)
        val_acc, _, _ = evaluate(model, val_loader, device)

        print(f">>> Epoch {epoch+1} | Train Loss: {avg_loss:.4f} | Val Acc: {val_acc:.4f} | LR: {scheduler.get_last_lr()[0]:.2e}")

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save({
                'model_state_dict': model.state_dict(),
                'epoch': epoch,
                'val_acc': val_acc,
                'label_map': train_ds.label_map
            }, "best_multimodal_model.pt")
            print(f"‚úî Saved best model (Acc: {val_acc:.4f})")
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= patience:
                print(f"\n‚ö†Ô∏è Early stopping after {epoch+1} epochs (no improvement for {patience} epochs)")
                break

        scheduler.step()

    print(f"\nüéâ Training complete. Best Val Acc = {best_acc:.4f}")
    return model, train_ds.label_map


# =====================================
# Test function
# =====================================
def test_model(model_path, test_csv, audio_root, label_map, batch_size=8):
    """Âú®ÊµãËØïÈõÜ‰∏äËØÑ‰º∞Ê®°ÂûãÊÄßËÉΩ"""
    from sklearn.metrics import classification_report, confusion_matrix
    import numpy as np

    print("\n" + "="*50)
    print("üìä Testing on Test Set")
    print("="*50 + "\n")

    # Âä†ËΩΩÊ®°Âûã
    tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")

    test_ds = PodcastDataset(test_csv, audio_root, tokenizer, feature_extractor)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=2)

    num_classes = len(label_map)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = ImprovedAudioTextClassifier(num_classes=num_classes, freeze_pretrained=False).to(device)

    # Âä†ËΩΩÊúÄ‰Ω≥Ê®°ÂûãÊùÉÈáç
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"‚úî Loaded best model from epoch {checkpoint['epoch']+1} (Val Acc: {checkpoint['val_acc']:.4f})\n")

    # ËØÑ‰º∞
    test_acc, all_preds, all_labels = evaluate(model, test_loader, device)

    # ÂèçËΩ¨ label_map ‰ª•Ëé∑ÂèñÁ±ªÂà´ÂêçÁß∞
    idx_to_label = {v: k for k, v in label_map.items()}
    class_names = [idx_to_label[i] for i in range(num_classes)]

    # ÊâìÂç∞ÊÄª‰ΩìÁªìÊûú
    print(f"{'='*50}")
    print(f"üéØ Test Accuracy: {test_acc:.4f} ({test_acc*100:.2f}%)")
    print(f"{'='*50}\n")

    # ËØ¶ÁªÜÂàÜÁ±ªÊä•Âëä
    print("üìã Classification Report:")
    print("-" * 50)
    print(classification_report(all_labels, all_preds, target_names=class_names, digits=4))

    # Ê∑∑Ê∑ÜÁü©Èòµ
    print("\nüìä Confusion Matrix:")
    print("-" * 50)
    cm = confusion_matrix(all_labels, all_preds)

    # ÊâìÂç∞Ê†ºÂºèÂåñÁöÑÊ∑∑Ê∑ÜÁü©Èòµ
    header = "True\\Pred  " + "  ".join([f"{name[:8]:>8}" for name in class_names])
    print(header)
    print("-" * len(header))
    for i, row in enumerate(cm):
        row_str = f"{class_names[i][:10]:<10} " + "  ".join([f"{val:>8}" for val in row])
        print(row_str)

    # ÊØè‰∏™Á±ªÂà´ÁöÑÂáÜÁ°ÆÁéá
    print("\nüìà Per-Class Accuracy:")
    print("-" * 50)
    class_correct = cm.diagonal()
    class_total = cm.sum(axis=1)
    for i, name in enumerate(class_names):
        acc = class_correct[i] / class_total[i] if class_total[i] > 0 else 0
        print(f"{name:20s}: {acc:.4f} ({class_correct[i]:3d}/{class_total[i]:3d})")

    # ÊúÄÂÆπÊòìÊ∑∑Ê∑ÜÁöÑÁ±ªÂà´ÂØπ
    print("\nüîÄ Most Confused Class Pairs:")
    print("-" * 50)
    confusion_pairs = []
    for i in range(num_classes):
        for j in range(num_classes):
            if i != j and cm[i, j] > 0:
                confusion_pairs.append((cm[i, j], class_names[i], class_names[j]))

    confusion_pairs.sort(reverse=True)
    for count, true_class, pred_class in confusion_pairs[:10]:
        print(f"{true_class:15s} ‚Üí {pred_class:15s}: {count:3d} times")

    print("\n" + "="*50)
    print("‚úÖ Testing Complete!")
    print("="*50)

    return test_acc, all_preds, all_labels


# =====================================
# Run training and testing
# =====================================
if __name__ == "__main__":
    # Dataset root: the *_transcripts.csv split files live directly under it,
    # and it is also handed to PodcastDataset as the audio root.
    datadir = "/content/drive/My Drive/CS441/FP2/data/"

    # 1) Fine-tune. train_model keeps the checkpoint with the best validation
    #    accuracy in "best_multimodal_model.pt" (relative to the working dir).
    # NOTE(review): that relative path is not on Drive; copy it to save_dir if
    #    it must survive a Colab runtime reset.
    model, label_map = train_model(
        audio_root=datadir,
        train_csv=os.path.join(datadir, "train_transcripts.csv"),
        val_csv=os.path.join(datadir, "val_transcripts.csv"),
        epochs=20,
        warmup_epochs=2,
        batch_size=8,
        lr=2e-5,
    )

    # 2) Reload that best checkpoint and score it on the held-out test split,
    #    passing the training label_map so class names/indices match training.
    test_acc, test_preds, test_labels = test_model(
        audio_root=datadir,
        test_csv=os.path.join(datadir, "test_transcripts.csv"),
        model_path="best_multimodal_model.pt",
        label_map=label_map,
        batch_size=8,
    )

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


📊 Loaded 1080 samples, 6 classes: ['business', 'comedy', 'education', 'news', 'religion', 'sports']
📊 Loaded 240 samples, 6 classes: ['business', 'comedy', 'education', 'news', 'religion', 'sports']


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



🚀 Starting training...
Device: cuda
Total parameters: 162,446,472
Trainable parameters: 119,919,240



Epoch 1/20: 100%|██████████| 135/135 [00:34<00:00,  3.96it/s, loss=1.7530]


>>> Epoch 1 | Train Loss: 1.8151 | Val Acc: 0.3125 | LR: 2.00e-06
✔ Saved best model (Acc: 0.3125)


Epoch 2/20: 100%|██████████| 135/135 [00:33<00:00,  4.01it/s, loss=1.6621]


>>> Epoch 2 | Train Loss: 1.7149 | Val Acc: 0.5083 | LR: 1.99e-06
✔ Saved best model (Acc: 0.5083)


Epoch 3/20: 100%|██████████| 135/135 [00:34<00:00,  3.93it/s, loss=1.3811]


>>> Epoch 3 | Train Loss: 1.6056 | Val Acc: 0.5583 | LR: 1.95e-06
✔ Saved best model (Acc: 0.5583)


Epoch 4/20: 100%|██████████| 135/135 [00:32<00:00,  4.11it/s, loss=1.5300]


>>> Epoch 4 | Train Loss: 1.4858 | Val Acc: 0.5542 | LR: 1.89e-06


Epoch 5/20: 100%|██████████| 135/135 [00:32<00:00,  4.13it/s, loss=1.4863]


>>> Epoch 5 | Train Loss: 1.3547 | Val Acc: 0.5500 | LR: 1.81e-06


Epoch 6/20: 100%|██████████| 135/135 [00:32<00:00,  4.12it/s, loss=0.9473]


>>> Epoch 6 | Train Loss: 1.2586 | Val Acc: 0.5417 | LR: 1.71e-06


Epoch 7/20: 100%|██████████| 135/135 [00:32<00:00,  4.16it/s, loss=1.1882]


>>> Epoch 7 | Train Loss: 1.1893 | Val Acc: 0.5417 | LR: 1.59e-06


Epoch 8/20: 100%|██████████| 135/135 [00:32<00:00,  4.16it/s, loss=0.9935]


>>> Epoch 8 | Train Loss: 1.1412 | Val Acc: 0.5708 | LR: 1.45e-06
✔ Saved best model (Acc: 0.5708)


Epoch 9/20: 100%|██████████| 135/135 [00:34<00:00,  3.96it/s, loss=1.3785]


>>> Epoch 9 | Train Loss: 1.0776 | Val Acc: 0.5792 | LR: 1.31e-06
✔ Saved best model (Acc: 0.5792)


Epoch 10/20: 100%|██████████| 135/135 [00:34<00:00,  3.90it/s, loss=1.3659]


>>> Epoch 10 | Train Loss: 1.0407 | Val Acc: 0.6000 | LR: 1.16e-06
✔ Saved best model (Acc: 0.6000)


Epoch 11/20: 100%|██████████| 135/135 [00:33<00:00,  4.06it/s, loss=1.0686]


>>> Epoch 11 | Train Loss: 1.0012 | Val Acc: 0.5958 | LR: 1.00e-06


Epoch 12/20: 100%|██████████| 135/135 [00:32<00:00,  4.16it/s, loss=0.7608]


>>> Epoch 12 | Train Loss: 0.9661 | Val Acc: 0.5792 | LR: 8.44e-07


Epoch 13/20: 100%|██████████| 135/135 [00:32<00:00,  4.15it/s, loss=1.1121]


>>> Epoch 13 | Train Loss: 0.9539 | Val Acc: 0.6000 | LR: 6.91e-07


Epoch 14/20: 100%|██████████| 135/135 [00:32<00:00,  4.17it/s, loss=1.0329]


>>> Epoch 14 | Train Loss: 0.8933 | Val Acc: 0.5917 | LR: 5.46e-07


Epoch 15/20: 100%|██████████| 135/135 [00:32<00:00,  4.18it/s, loss=1.1337]


>>> Epoch 15 | Train Loss: 0.9118 | Val Acc: 0.6000 | LR: 4.12e-07

⚠️ Early stopping after 15 epochs (no improvement for 5 epochs)

🎉 Training complete. Best Val Acc = 0.6000

📊 Testing on Test Set

📊 Loaded 240 samples, 6 classes: ['business', 'comedy', 'education', 'news', 'religion', 'sports']


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


✔ Loaded best model from epoch 10 (Val Acc: 0.6000)

🎯 Test Accuracy: 0.7875 (78.75%)

📋 Classification Report:
--------------------------------------------------
              precision    recall  f1-score   support

    business     0.6981    0.9250    0.7957        40
      comedy     0.7105    0.6750    0.6923        40
   education     0.8000    0.7000    0.7467        40
        news     0.7097    0.5500    0.6197        40
    religion     0.9524    1.0000    0.9756        40
      sports     0.8537    0.8750    0.8642        40

    accuracy                         0.7875       240
   macro avg     0.7874    0.7875    0.7824       240
weighted avg     0.7874    0.7875    0.7824       240


📊 Confusion Matrix:
--------------------------------------------------
True\Pred  business    comedy  educatio      news  religion    sports
---------------------------------------------------------------------
business         37         0         2         1         0         0
c