# In [None]:
# notebook_demo.ipynb

# ---
# # Twitch Multimodal Demo
#
# This notebook demonstrates how to use the code in the "my-twitch-multimodal" repository.
# ---

import torch
import os
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim

# 1) Imports from our local modules
from data_utils import build_vocabulary, TwitchCommentDataset, my_collate_fn
from data_utils import load_chat  # optional
from preprocess import precompute_features
from model import MultiModalLSTM
from train import train_one_epoch
from inference import generate_comment

# Prefer the GPU when CUDA is usable; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Input data locations -- example values, adapt them to your folder structure.
CHAT_FILE = "data/chat.txt"
VIDEO_FILE = "data/video.mp4"
AUDIO_FILE = "data/audio.wav"

# Output locations: the feature cache directory and the saved model weights.
CACHE_DIR = "cached_features"
MODEL_PATH = "my_multimodal_model.pth"

# 2) Build the vocabulary from the chat file, requesting the special tokens
special_toks = [
    "<PAD>",
    "<UNK>",
    "<SOS>",
    "<EOS>",
]

# build_vocabulary returns a pair, unpacked here as word2idx / idx2word
# (presumably token -> id and id -> token lookups; confirm in data_utils).
word2idx, idx2word = build_vocabulary(
    chat_file=CHAT_FILE, min_freq=1, max_size=5000, special_tokens=special_toks
)

print("Vocabulary size:", len(word2idx))

# 3) Precompute features (if not done yet)
# The cache counts as "done" only when the directory exists AND contains at
# least one entry. Checking for existence alone is not enough: the directory
# is created before precompute_features runs, so a run that fails part-way
# would leave an empty directory behind and every later run would silently
# skip feature extraction.
# NOTE(review): a non-empty directory is still assumed to be complete; a
# partially written cache is not detected here.
if not os.path.isdir(CACHE_DIR) or not os.listdir(CACHE_DIR):
    os.makedirs(CACHE_DIR, exist_ok=True)
    precompute_features(
        video_path=VIDEO_FILE,
        audio_path=AUDIO_FILE,
        chat_file=CHAT_FILE,
        output_dir=CACHE_DIR
    )

# 4) Create the dataset/dataloader
# The dataset is built from the feature cache directory, the chat file and the
# word -> index vocabulary.
dataset = TwitchCommentDataset(cache_dir=CACHE_DIR, chat_file=CHAT_FILE, word2idx=word2idx)

# Shuffled mini-batches of 4 samples, assembled by the project's own collate function.
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=my_collate_fn)

# 5) Initialize model
# Input sizes handed to the model constructor.
vocab_size = len(word2idx)
video_feature_dim = 3 * 224 * 224  # presumably a flattened 3x224x224 frame -- TODO confirm against preprocess
audio_feature_dim = 64

# Build the network, then move its parameters onto the selected device.
model = MultiModalLSTM(vocab_size, video_feature_dim, audio_feature_dim, hidden_dim=512)
model = model.to(device)

# 6) Train or load model
train_model = True  # set to False to skip training

if train_model:
    # A simple training loop
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # Padding positions must not contribute to the loss. Look the <PAD> index
    # up in the vocabulary instead of assuming it is 0; the fallback of 0
    # keeps the previous hard-coded behavior if the token is absent.
    pad_idx = word2idx.get("<PAD>", 0)
    criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

    for epoch in range(2):  # just 2 epochs for demo
        # Epoch numbers are passed 1-based to train_one_epoch and to the log line.
        loss_val = train_one_epoch(model, dataloader, optimizer, criterion, epoch+1, device)
        print(f"Epoch {epoch+1}, Loss={loss_val:.4f}")

    # Persist only the weights (state_dict), not the whole module object.
    torch.save(model.state_dict(), MODEL_PATH)
    print("Saved model to:", MODEL_PATH)
else:
    # Restore previously saved weights, remapping tensors onto the current device.
    # NOTE(review): torch.load unpickles the file, so only load checkpoints you
    # trust; on PyTorch versions that support it, consider weights_only=True.
    checkpoint = torch.load(MODEL_PATH, map_location=device)
    model.load_state_dict(checkpoint)
    print("Loaded model from:", MODEL_PATH)

# 7) Inference example
# Switch to evaluation mode before generating: torch modules start out in
# training mode, in which layers such as dropout stay active and would add
# noise to the generated output.
model.eval()

# Use the first dataset item as the conditioning input; items are indexed by
# the 'video' and 'audio' keys.
sample = dataset[0]
video_sample = sample['video']
audio_sample = sample['audio']

# Start/end markers for generation.
# NOTE(review): both lookups fall back to 0 when the token is missing, and 0
# is treated as the <PAD> index elsewhere in this script -- a missing special
# token would silently alias padding. Confirm build_vocabulary always adds them.
start_idx = word2idx.get("<SOS>", 0)
end_idx   = word2idx.get("<EOS>", 0)

gen_ids = generate_comment(
    model=model,
    video_tensor=video_sample,
    audio_tensor=audio_sample,
    start_token_idx=start_idx,
    end_token_idx=end_idx,
    max_len=20,
    device=device,
    temperature=0.8,
    top_k=5
)

# Convert IDs to text: invert word2idx into an id -> word map and render any
# id that is not in the vocabulary as "<UNK>".
rev_vocab = {v: k for k, v in word2idx.items()}
decoded = [rev_vocab.get(tok_id, "<UNK>") for tok_id in gen_ids]
print("Generated Comment:", " ".join(decoded))
