<a href="https://colab.research.google.com/github/lazywriter/vades_implementation/blob/main/Copy_of_vades_implementation_v4_yash.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ============================================================================
# SETUP
# ============================================================================

!pip install -q sentence-transformers spacy textstat joblib scikit-learn beautifulsoup4
!python -m spacy download en_core_web_sm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import os, re, string, random, joblib
from collections import defaultdict, Counter
from tqdm import tqdm
from sklearn.svm import SVR
from sklearn.model_selection import cross_val_score, train_test_split
from scipy.spatial.distance import pdist
import spacy, textstat
from bs4 import BeautifulSoup
from sentence_transformers import SentenceTransformer
!pip install -U sentence-transformers


from google.colab import drive
drive.mount('/content/drive')

DRIVE_ROOT = "/content/drive/MyDrive/final_year_project_books_folder"
BOOKS_DIR = os.path.join(DRIVE_ROOT, "gutenberg_books")
CACHE_DIR = os.path.join(DRIVE_ROOT, "feat_cache")
CKPT_DIR = os.path.join(DRIVE_ROOT, "vades_final_rebuild")
SPLIT_DIR = os.path.join(DRIVE_ROOT, "splits_final")
OUTPUT_CSV = os.path.join(DRIVE_ROOT, "gutenberg_dataset_clean.csv")

for d in [CACHE_DIR, CKPT_DIR, SPLIT_DIR]:
    os.makedirs(d, exist_ok=True)

RND = 42
random.seed(RND)
np.random.seed(RND)
torch.manual_seed(RND)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# ============================================================================
# DATA LOADING (WITH CACHING)
# ============================================================================

def extract_text_from_html(path):
    with open(path, "r", encoding="utf-8", errors="ignore") as f:
        soup = BeautifulSoup(f, "html.parser")
        for tag in soup(["table", "script", "style"]): tag.extract()
        return soup.get_text(separator=" ", strip=True)

def clean_gutenberg_text(text):
    text = re.sub(r'(?is).*?(the project gutenberg ebook of )', '', text, count=1)
    text = re.sub(r'(?is)end of (the )?project gutenberg ebook.*', '', text)
    text = re.sub(r'(?is)\*\*\*.*?project gutenberg.*?\*\*\*', '', text)
    text = re.sub(r'(?i)(produced by|transcribed from).*?\n', '', text)
    text = re.sub(r'\b(CHAPTER|Chapter)\s+[A-Z0-9IVX]+\b', '', text)
    text = re.sub(r'[_*#~-]{2,}', ' ', text)
    text = re.sub(r'http\S+', '', text)
    return re.sub(r'\s+', ' ', text).strip()

def chunk_text(text, size=4000, overlap=500):
    chunks, start = [], 0
    while start < len(text):
        chunks.append(text[start:start+size])
        start += size - overlap
    return chunks

if os.path.exists(OUTPUT_CSV):
    df = pd.read_csv(OUTPUT_CSV)
    print(f"✓ Loaded {len(df)} chunks")
else:
    data = []
    for author in tqdm(os.listdir(BOOKS_DIR)):
        folder = os.path.join(BOOKS_DIR, author)
        if not os.path.isdir(folder): continue
        for fname in os.listdir(folder):
            if not fname.endswith(".html"): continue
            try:
                raw = extract_text_from_html(os.path.join(folder, fname))
                clean = clean_gutenberg_text(raw)
                for chunk in chunk_text(clean):
                    if chunk.strip():
                        data.append({"author": author, "file_name": fname, "text": chunk})
            except: pass
    df = pd.DataFrame(data)
    df.to_csv(OUTPUT_CSV, index=False)

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/176.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m176.4/176.4 kB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m2.1/2.1 MB[0m [31m107.1 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m59.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting en-core-web-sm==3.8.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl (12.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.8/12.8 MB[0m [31m149.6 MB/s[0m eta [36m0:00:00[0m
[?25h[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('

ValueError: mount failed

In [None]:
# ============================================================================
# BOOK-LEVEL SPLIT (NO DATA LEAKAGE)
# ============================================================================

train_csv, val_csv, test_csv = [os.path.join(SPLIT_DIR, f"{s}.csv") for s in ["train", "val", "test"]]

if all(os.path.exists(f) for f in [train_csv, val_csv, test_csv]):
    train_df, val_df, test_df = pd.read_csv(train_csv), pd.read_csv(val_csv), pd.read_csv(test_csv)
    print(f"✓ Loaded splits")
else:
    df['book_id'] = df['author'] + '___' + df['file_name']
    books = df[['author', 'book_id']].drop_duplicates()

    train_books, temp = train_test_split(books, test_size=0.3, stratify=books['author'], random_state=RND)
    val_books, test_books = train_test_split(temp, test_size=0.5, stratify=temp['author'], random_state=RND)

    train_df = df[df['book_id'].isin(train_books['book_id'])].reset_index(drop=True)
    val_df = df[df['book_id'].isin(val_books['book_id'])].reset_index(drop=True)
    test_df = df[df['book_id'].isin(test_books['book_id'])].reset_index(drop=True)

    # Verify no leakage
    assert len(set(train_df['book_id']) & set(test_df['book_id'])) == 0

    train_df.to_csv(train_csv, index=False)
    val_df.to_csv(val_csv, index=False)
    test_df.to_csv(test_csv, index=False)

print(f"Train: {len(train_df)}, Val: {len(val_df)}, Test: {len(test_df)}")

✓ Loaded splits
Train: 10809, Val: 2454, Test: 1935


In [None]:
# ============================================================================
# STYLE FEATURES (300-D, WITH CACHING)
# ============================================================================

nlp = spacy.load("en_core_web_sm", disable=["parser"])
FUNCTION_WORDS = ["the","be","to","of","and","a","in","that","have","i","it","for","not","on","with","he","as","you","do","at","this","but","his","by","from","they","we","say","her","she","or","an","will","my","one","all","would","there","their","what","so","up","out","if","about","who","get","which","go","me","when","make","can","like","time","no","just","him","know","take"]

def build_style_vector(text, dim=300):
    feats = []
    words = text.split()
    nw = len(words)

    # Basic (5)
    feats += [len(text), nw, max(1, len(re.split(r'[.!?]+', text))), nw/max(1,len(re.split(r'[.!?]+',text))), sum(len(w) for w in words)/max(1,nw)]

    # Punctuation (11)
    punc_counts = Counter(ch for ch in text if ch in string.punctuation)
    for k in ['.',',',';','!','?',"'",'"',':','-','(',')']: feats.append(punc_counts.get(k,0)/max(1,nw))

    # Function words (60)
    word_lower = [w.lower() for w in re.findall(r"\w+", text)]
    wc = Counter(word_lower)
    feats += [wc.get(fw,0)/max(1,len(word_lower)) for fw in FUNCTION_WORDS[:60]]

    # Lexical (3)
    types = set(word_lower)
    feats += [len(types)/max(1,len(word_lower)), sum(1 for _,c in wc.items() if c==1)/max(1,nw), 0.0]

    # Syllables (2)
    try: sylls = textstat.syllable_count(text)
    except: sylls = sum(max(1,sum(1 for c in w if c.lower() in'aeiou')) for w in words)
    feats += [sylls, sylls/max(1,nw)]

    # Readability (4)
    try: feats += [textstat.flesch_reading_ease(text), textstat.flesch_kincaid_grade(text), textstat.smog_index(text), textstat.coleman_liau_index(text)]
    except: feats += [0.0]*4

    # POS+NER (28)
    doc = nlp(text)
    pos_c = Counter(t.pos_ for t in doc)
    ner_c = Counter(e.label_ for e in doc.ents)
    for k in ['NOUN','VERB','ADJ','ADV','PRON','DET','ADP','NUM','PROPN','PART','INTJ','SYM','X','PUNCT']:
        feats.append(pos_c.get(k,0)/max(1,len(doc)))
    for k in ['PERSON','NORP','ORG','GPE','LOC','PRODUCT','EVENT','WORK_OF_ART','LAW','LANGUAGE','DATE','TIME','PERCENT','MONEY']:
        feats.append(ner_c.get(k,0)/max(1,len(doc)))

    # Char n-grams (50)
    s = re.sub(r'\s+','',text.lower())
    char_ng = Counter()
    for n in [2,3,4]:
        for i in range(len(s)-n+1): char_ng[s[i:i+n]]+=1
    top = [g for g,_ in char_ng.most_common(50)]
    tot = max(1,sum(char_ng.values()))
    feats += [char_ng.get(g,0)/tot for g in top]
    if len(feats) < 85+50: feats += [0.0]*(85+50-len(feats))

    # Word bigrams (50)
    word_bg = Counter()
    for i in range(len(word_lower)-1): word_bg[" ".join(word_lower[i:i+2])]+=1
    top = [g for g,_ in word_bg.most_common(50)]
    tot = max(1,sum(word_bg.values()))
    feats += [word_bg.get(g,0)/tot for g in top]
    if len(feats) < 185: feats += [0.0]*(185-len(feats))

    # Distributional (2)
    probs = np.array([v/max(1,sum(wc.values())) for v in wc.values()], dtype=float)
    feats += [-np.sum(probs*np.log2(probs+1e-12)), max(wc.values())/max(1,sum(wc.values())) if wc else 0.0]

    # Extra (3)
    feats += [sum(1 for c in text if c.isdigit())/max(1,len(text)), sum(1 for w in words if w.isupper())/max(1,nw), sum(1 for w in word_lower if w in spacy.lang.en.stop_words.STOP_WORDS)/max(1,len(word_lower))]

    feats = feats[:dim] + [0.0]*(dim-len(feats))
    return np.array(feats, dtype=np.float32)

def cache_features(df, path):
    if os.path.exists(path):
        print(f"✓ Loading {os.path.basename(path)}")
        return joblib.load(path)
    feats = [build_style_vector(t) for t in tqdm(df['text'], desc="Features")]
    joblib.dump(np.stack(feats), path)
    return np.stack(feats)

train_feats = cache_features(train_df, os.path.join(CACHE_DIR, "train_feats_rebuild.joblib"))
val_feats = cache_features(val_df, os.path.join(CACHE_DIR, "val_feats_rebuild.joblib"))
test_feats = cache_features(test_df, os.path.join(CACHE_DIR, "test_feats_rebuild.joblib"))

# Embeddings - UPDATED MODEL
emb_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2", device=device)

def cache_embeddings(df, path):
    if os.path.exists(path):
        print(f"✓ Loading {os.path.basename(path)}")
        return joblib.load(path)
    embs = emb_model.encode(df['text'].tolist(), show_progress_bar=True, batch_size=64, convert_to_numpy=True)
    joblib.dump(embs, path)
    return embs

train_embs = cache_embeddings(train_df, os.path.join(CACHE_DIR, "train_embs_rebuild.joblib"))
val_embs = cache_embeddings(val_df, os.path.join(CACHE_DIR, "val_embs_rebuild.joblib"))
test_embs = cache_embeddings(test_df, os.path.join(CACHE_DIR, "test_embs_rebuild.joblib"))

print(f"Features: {train_feats.shape}, Embeddings: {train_embs.shape}")


✓ Loading train_feats_rebuild.joblib
✓ Loading val_feats_rebuild.joblib
✓ Loading test_feats_rebuild.joblib
✓ Loading train_embs_rebuild.joblib
✓ Loading val_embs_rebuild.joblib
✓ Loading test_embs_rebuild.joblib
Features: (10809, 300), Embeddings: (10809, 768)


In [None]:
# ============================================================================
# VADES MODEL (SIMPLIFIED, WORKING VERSION)
# ============================================================================

class VADES(nn.Module):
    def __init__(self, emb_dim=768, style_dim=300, hidden=512, n_authors=10):
        super().__init__()

        # Document encoder
        self.encoder = nn.Sequential(
            nn.Linear(emb_dim, hidden), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(hidden, style_dim)
        )

        # Author embeddings (simple, no variance)
        self.authors = nn.Embedding(n_authors, style_dim)
        nn.init.xavier_uniform_(self.authors.weight)

    def forward(self, doc_emb):
        return self.encoder(doc_emb)

# Simple contrastive loss (no VIB complexity)
def simple_contrastive_loss(model, doc_emb, author_ids, style_feats, neg_author_ids, neg_style_feats, alpha=0.1, margin=1.0):

    # Get representations
    doc_emb_out = model(doc_emb)  # [batch, 300]
    author_emb = model.authors(author_ids)  # [batch, 300]

    # Positive distances (should be small)
    pos_dist_author = torch.sum((doc_emb_out - author_emb) ** 2, dim=1)  # [batch]
    pos_dist_style = torch.sum((doc_emb_out - style_feats) ** 2, dim=1)  # [batch]

    # Negative distances (should be large)
    batch_size = doc_emb.size(0)
    k_neg = neg_author_ids.size(1)

    neg_author_emb = model.authors(neg_author_ids.view(-1)).view(batch_size, k_neg, -1)  # [batch, k, 300]
    doc_exp = doc_emb_out.unsqueeze(1)  # [batch, 1, 300]
    neg_dist_author = torch.sum((doc_exp - neg_author_emb) ** 2, dim=2)  # [batch, k]

    neg_style_exp = neg_style_feats  # [batch, k, 300]
    neg_dist_style = torch.sum((doc_exp - neg_style_exp) ** 2, dim=2)  # [batch, k]

    # Triplet-like loss: positive should be < negative
    author_loss = torch.relu(pos_dist_author.unsqueeze(1) - neg_dist_author + margin).mean()
    style_loss = torch.relu(pos_dist_style.unsqueeze(1) - neg_dist_style + margin).mean()

    total_loss = (1 - alpha) * author_loss + alpha * style_loss

    return total_loss, author_loss, style_loss

In [None]:
# # ============================================================================
# # VADES MODEL (COSINE-BASED VERSION)
# # ============================================================================

# import torch
# import torch.nn as nn
# import torch.nn.functional as F

# class VADES(nn.Module):
#     def __init__(self, emb_dim=768, style_dim=300, hidden=512, n_authors=10):
#         super().__init__()

#         # Document encoder
#         self.encoder = nn.Sequential(
#             nn.Linear(emb_dim, hidden),
#             nn.ReLU(),
#             nn.Dropout(0.5),      # Stronger regularization for small datasets
#             nn.Linear(hidden, style_dim)
#         )

#         # Author embeddings
#         self.authors = nn.Embedding(n_authors, style_dim)
#         nn.init.xavier_uniform_(self.authors.weight)

#     def forward(self, doc_emb):
#         return self.encoder(doc_emb)


# # ============================================================================
# # COSINE CONTRASTIVE LOSS
# # ============================================================================

# def cosine_contrastive_loss(model, doc_emb, author_ids, style_feats,
#                             neg_author_ids, neg_style_feats, alpha=0.1, margin=0.5):
#     """
#     Contrastive loss based on cosine similarity.
#     Encourages doc embeddings to align with their author & style,
#     and diverge from negatives.
#     """

#     # Encode documents
#     doc_emb_out = model(doc_emb)  # [batch, 300]
#     author_emb = model.authors(author_ids)  # [batch, 300]

#     batch_size = doc_emb.size(0)
#     k_neg = neg_author_ids.size(1)

#     # Negative author embeddings
#     neg_author_emb = model.authors(neg_author_ids.view(-1)).view(batch_size, k_neg, -1)

#     # Normalize all embeddings (critical for cosine similarity)
#     doc_emb_out = F.normalize(doc_emb_out, p=2, dim=1)
#     author_emb = F.normalize(author_emb, p=2, dim=1)
#     style_feats = F.normalize(style_feats, p=2, dim=1)
#     neg_author_emb = F.normalize(neg_author_emb, p=2, dim=2)
#     neg_style_feats = F.normalize(neg_style_feats, p=2, dim=2)

#     # ---- Positive cosine distances (want them small) ----
#     pos_dist_author = 1 - F.cosine_similarity(doc_emb_out, author_emb)  # [batch]
#     pos_dist_style = 1 - F.cosine_similarity(doc_emb_out, style_feats)  # [batch]

#     # ---- Negative cosine distances (want them large) ----
#     doc_exp = doc_emb_out.unsqueeze(1)  # [batch, 1, 300]
#     # Cosine similarity via batched matrix multiplication
#     neg_sim_author = torch.bmm(neg_author_emb, doc_exp.transpose(1, 2)).squeeze(2)  # [batch, k]
#     neg_sim_style = torch.bmm(neg_style_feats, doc_exp.transpose(1, 2)).squeeze(2)  # [batch, k]

#     neg_dist_author = 1 - neg_sim_author
#     neg_dist_style = 1 - neg_sim_style

#     # ---- Triplet-like margin loss ----
#     author_loss = torch.relu(pos_dist_author.unsqueeze(1) - neg_dist_author + margin).mean()
#     style_loss = torch.relu(pos_dist_style.unsqueeze(1) - neg_dist_style + margin).mean()

#     total_loss = (1 - alpha) * author_loss + alpha * style_loss

#     return total_loss, author_loss, style_loss


In [None]:
# ============================================================================
# DATASET
# ============================================================================

class SimpleDataset(Dataset):
    def __init__(self, embs, feats, aids):
        self.embs = torch.FloatTensor(embs)
        self.feats = torch.FloatTensor(feats)
        self.aids = torch.LongTensor(aids)
    def __len__(self): return len(self.embs)
    def __getitem__(self, i): return self.embs[i], self.feats[i], self.aids[i]

author2id = {a: i for i, a in enumerate(sorted(train_df['author'].unique()))}
num_authors = len(author2id)

train_aids = np.array([author2id[a] for a in train_df['author']])
val_aids = np.array([author2id.get(a,-1) for a in val_df['author']])
test_aids = np.array([author2id.get(a,-1) for a in test_df['author']])

val_mask, test_mask = val_aids>=0, test_aids>=0
val_embs, val_feats, val_aids = val_embs[val_mask], val_feats[val_mask], val_aids[val_mask]
test_embs, test_feats, test_aids = test_embs[test_mask], test_feats[test_mask], test_aids[test_mask]

BATCH = 32
train_loader = DataLoader(SimpleDataset(train_embs, train_feats, train_aids), BATCH, shuffle=True, drop_last=True)
val_loader = DataLoader(SimpleDataset(val_embs, val_feats, val_aids), BATCH, shuffle=False)
test_loader = DataLoader(SimpleDataset(test_embs, test_feats, test_aids), BATCH, shuffle=False)

print(f"Authors: {num_authors}, Train batches: {len(train_loader)}")

Authors: 10, Train batches: 337


In [None]:
# ============================================================================
# TRAINING
# ============================================================================

model = VADES(n_authors=num_authors).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-7)

EPOCHS, ALPHA, MARGIN, K_NEG = 15, 0.1, 0.5, 5

print(f"\n{'='*80}\nTRAINING (Simplified)\n{'='*80}")
print(f"Epochs: {EPOCHS}, Alpha: {ALPHA}, Margin: {MARGIN}, K_neg: {K_NEG}\n")

best_val_loss = float('inf')

for epoch in range(EPOCHS):
    model.train()
    losses = []

    for doc_emb, style_feat, aid in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        doc_emb, style_feat, aid = doc_emb.to(device), style_feat.to(device), aid.to(device)

        bs = doc_emb.size(0)

        # Sample negatives
        neg_aids = torch.randint(0, num_authors, (bs, K_NEG), device=device)
        for i in range(bs):
            for k in range(K_NEG):
                while neg_aids[i,k] == aid[i]:
                    neg_aids[i,k] = torch.randint(0, num_authors, (1,), device=device)

        idx = torch.randperm(bs, device=device)
        neg_style = torch.stack([style_feat[torch.roll(idx,k+1)] for k in range(K_NEG)], dim=1)

        optimizer.zero_grad()
        loss, a_loss, s_loss = simple_contrastive_loss(model, doc_emb, aid, style_feat, neg_aids, neg_style, ALPHA, MARGIN)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        if epoch == 0 and len(losses) == 1:
            print(f"First batch - Loss: {loss.item():.4f}, Author: {a_loss.item():.4f}, Style: {s_loss.item():.4f}")

    # Val
    model.eval()
    val_losses = []
    with torch.no_grad():
        for doc_emb, style_feat, aid in val_loader:
            doc_emb, style_feat, aid = doc_emb.to(device), style_feat.to(device), aid.to(device)
            doc_out = model(doc_emb)
            author_emb = model.authors(aid)
            val_loss = torch.sum((doc_out - author_emb)**2, dim=1).mean()
            val_losses.append(val_loss.item())

    avg_train, avg_val = np.mean(losses), np.mean(val_losses)
    print(f"Epoch {epoch+1}: Train={avg_train:.4f}, Val={avg_val:.4f}")

    if avg_val < best_val_loss:
        best_val_loss = avg_val
        torch.save(model.state_dict(), os.path.join(CKPT_DIR, 'best.pth'))
        print(f"  ✅ Saved")


TRAINING (Simplified)
Epochs: 15, Alpha: 0.1, Margin: 0.5, K_neg: 5



Epoch 1:   3%|▎         | 10/337 [00:00<00:16, 19.54it/s]

First batch - Loss: 1916.1473, Author: 0.4774, Style: 19157.1758


Epoch 1: 100%|██████████| 337/337 [00:04<00:00, 73.97it/s]


Epoch 1: Train=4666.5006, Val=2.2503
  ✅ Saved


Epoch 2: 100%|██████████| 337/337 [00:03<00:00, 106.57it/s]


Epoch 2: Train=4754.1899, Val=2.2494
  ✅ Saved


Epoch 3: 100%|██████████| 337/337 [00:03<00:00, 106.21it/s]


Epoch 3: Train=4704.1279, Val=2.2486
  ✅ Saved


Epoch 4: 100%|██████████| 337/337 [00:03<00:00, 85.15it/s]


Epoch 4: Train=4783.5373, Val=2.2478
  ✅ Saved


Epoch 5: 100%|██████████| 337/337 [00:03<00:00, 93.86it/s] 


Epoch 5: Train=4716.6498, Val=2.2470
  ✅ Saved


Epoch 6: 100%|██████████| 337/337 [00:03<00:00, 107.03it/s]


Epoch 6: Train=4735.9640, Val=2.2463
  ✅ Saved


Epoch 7: 100%|██████████| 337/337 [00:03<00:00, 87.13it/s] 


Epoch 7: Train=4722.0411, Val=2.2455
  ✅ Saved


Epoch 8: 100%|██████████| 337/337 [00:03<00:00, 85.04it/s]


Epoch 8: Train=4677.5039, Val=2.2448
  ✅ Saved


Epoch 9: 100%|██████████| 337/337 [00:03<00:00, 105.31it/s]


Epoch 9: Train=4665.0687, Val=2.2441
  ✅ Saved


Epoch 10: 100%|██████████| 337/337 [00:03<00:00, 106.68it/s]


Epoch 10: Train=4736.3060, Val=2.2434
  ✅ Saved


Epoch 11: 100%|██████████| 337/337 [00:03<00:00, 106.17it/s]


Epoch 11: Train=4757.5253, Val=2.2427
  ✅ Saved


Epoch 12: 100%|██████████| 337/337 [00:03<00:00, 86.46it/s] 


Epoch 12: Train=4740.0651, Val=2.2420
  ✅ Saved


Epoch 13: 100%|██████████| 337/337 [00:03<00:00, 105.91it/s]


Epoch 13: Train=4703.7472, Val=2.2414
  ✅ Saved


Epoch 14: 100%|██████████| 337/337 [00:03<00:00, 106.12it/s]


Epoch 14: Train=4734.2384, Val=2.2408
  ✅ Saved


Epoch 15: 100%|██████████| 337/337 [00:03<00:00, 99.10it/s]


Epoch 15: Train=4669.3135, Val=2.2402
  ✅ Saved


In [None]:
# ============================================================================
# EVALUATION
# ============================================================================

model.load_state_dict(torch.load(os.path.join(CKPT_DIR, 'best.pth')))
model.eval()

# Authorship
author_means = defaultdict(list)
with torch.no_grad():
    for doc_emb, _, aid in train_loader:
        doc_emb = doc_emb.to(device)
        doc_out = model(doc_emb)
        for i, a in enumerate(aid.numpy()):
            author_means[a].append(doc_out[i].cpu().numpy())

author_means = {a: np.mean(embs, axis=0) for a, embs in author_means.items()}

correct, total = 0, 0
with torch.no_grad():
    for doc_emb, _, aid in test_loader:
        doc_emb = doc_emb.to(device)
        doc_out = model(doc_emb).cpu().numpy()
        for i, true_a in enumerate(aid.numpy()):
            dists = {a: np.linalg.norm(doc_out[i] - m) for a, m in author_means.items()}
            if min(dists, key=dists.get) == true_a: correct += 1
            total += 1

print(f"\n{'='*80}\nRESULTS\n{'='*80}")
print(f"Accuracy: {100*correct/total:.2f}% ({correct}/{total})")
print(f"Random: {100/num_authors:.2f}%")
print(f"\n✅ DONE")


RESULTS
Accuracy: 39.22% (759/1935)
Random: 10.00%

✅ DONE
