In [None]:
import json
import re
from collections import defaultdict
from tqdm import tqdm
import json
import re
from collections import defaultdict
from tqdm import tqdm

def clean_message(text):
    """Normalise a chat message before it is tokenised.

    Steps, in order: lower-case the text, delete ``http(s)://`` and ``www.``
    URLs, drop every character that is not ``a-z``, ``0-9`` or whitespace,
    then collapse whitespace runs to a single space and trim both ends.

    Args:
        text: Raw message string.

    Returns:
        The cleaned string (may be empty).
    """
    lowered = text.lower()
    # URLs go first so their fragments do not survive as stray alphanumerics.
    without_urls = re.sub(r'https?://\S+|www\.\S+', '', lowered)
    alnum_only = re.sub(r'[^a-z0-9\s]', '', without_urls)
    return re.sub(r'\s+', ' ', alnum_only).strip()

def add_history_to_jsonl(input_path, output_path, k=5):
    """Clean every message of every game and attach a per-pair history window.

    Each non-blank line of ``input_path`` is parsed as one game object holding
    parallel lists ``messages``, ``speakers`` and ``receivers``. For message
    ``i`` this function:

    * replaces ``game['messages'][i]`` in place with ``clean_message(...)``;
    * appends to ``game['history']`` the last ``k`` *cleaned* messages that the
      same speaker previously sent to the same receiver in the same game
      (oldest first, not including message ``i`` itself).

    Each augmented game is written to ``output_path`` as one JSON line.

    Args:
        input_path: Source JSONL file.
        output_path: Destination JSONL file (created or overwritten).
        k: Maximum number of previous messages per history entry. ``k <= 0``
            yields empty histories.
    """
    with open(input_path, 'r', encoding='utf-8') as f_in, open(output_path, 'w', encoding='utf-8') as f_out:
        for line in tqdm(f_in, desc="Processing games"):
            if not line.strip():
                continue

            game = json.loads(line)
            messages = game['messages']
            speakers = game['speakers']
            receivers = game['receivers']

            # Keyed by the directed (speaker, receiver) pair; rebuilt for every
            # game so history never leaks across games.
            history_lookup = defaultdict(list)
            game['history'] = []

            for i in range(len(messages)):
                pair_key = (speakers[i], receivers[i])
                cleaned = clean_message(messages[i])
                messages[i] = cleaned  # same list object as game['messages']
                previous = history_lookup[pair_key]
                # Guard k <= 0: previous[-0:] would return the *entire* history.
                game['history'].append(previous[-k:] if k > 0 else [])
                # Append only after snapshotting, so an entry excludes its own message.
                previous.append(cleaned)
            f_out.write(json.dumps(game) + "\n")


# Training split: clean messages and attach up to 10 prior same-pair messages.
add_history_to_jsonl(
    input_path="data/train.jsonl",
    output_path="train_with_history_10.jsonl",
    k=10,
)


Processing games: 189it [00:00, 688.42it/s]


In [3]:
# Held-out splits get exactly the same preprocessing (history window k=10).
add_history_to_jsonl(input_path="data/validation.jsonl", output_path="val_with_history_10.jsonl", k=10)
add_history_to_jsonl(input_path="data/test.jsonl", output_path="test_with_history_10.jsonl", k=10)

Processing games: 21it [00:00, 575.77it/s]
Processing games: 42it [00:00, 851.04it/s]


In [1]:
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from tqdm import tqdm
from sklearn.metrics import f1_score, accuracy_score
import numpy as np

import json
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import BertTokenizer, BertModel
from sklearn.metrics import f1_score, accuracy_score

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence.

    A table for ``max_len`` positions is precomputed and registered as the
    non-trainable buffer ``pe`` with shape ``(1, max_len, d_model)``.
    Even feature columns hold ``sin(pos * w_j)`` and odd columns hold
    ``cos(pos * w_j)``, where ``w_j = 10000 ** (-2j / d_model)``.
    An odd ``d_model`` is supported.

    Args:
        d_model: Feature width of the inputs.
        dropout: Dropout probability applied after the addition.
        max_len: Longest sequence length covered by the table.
    """

    def __init__(self, d_model, dropout=0.1, max_len=50):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        freqs = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)
        # There are d_model // 2 odd columns: one fewer than the even columns
        # when d_model is odd, so only that many frequencies feed the cosine.
        table[:, 1::2] = torch.cos(positions * freqs[: d_model // 2])
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + pe[:, :seq_len])``.

        ``x`` is expected to be ``(batch, seq_len, d_model)`` with
        ``seq_len <= max_len``.
        """
        seq_len = x.size(1)
        return self.dropout(x + self.pe[:, :seq_len])
class GameStateEncoder(nn.Module):
    """Map the flat numeric game-state vector to a dense ReLU feature.

    Input width is ``2 + season_vocab_size + year_vocab_size``: two scalar
    scores followed by the season and year one-hot blocks. The layer lives in
    ``self.fc`` as ``Sequential(Linear, ReLU)``. Keep that name and layout so
    saved ``state_dict`` keys (``fc.0.weight`` / ``fc.0.bias``) still load.

    Args:
        season_vocab_size: Width of the season one-hot block.
        year_vocab_size: Width of the year one-hot block.
        out_dim: Output feature width.
    """

    def __init__(self, season_vocab_size, year_vocab_size, out_dim=32):
        super().__init__()
        in_features = season_vocab_size + year_vocab_size + 2
        layers = [nn.Linear(in_features, out_dim), nn.ReLU()]
        self.fc = nn.Sequential(*layers)

    def forward(self, game_features):
        """Return ``relu(W @ game_features + b)`` with shape ``(..., out_dim)``."""
        return self.fc(game_features)
class HistoryEncoderTransformer(nn.Module):
    """Summarise each example's list of previous messages into one vector.

    Every history message is embedded with its own BERT ``[CLS]`` vector (BERT
    runs under ``no_grad`` here). The per-message vectors are given positional
    encodings, passed through a small ``nn.TransformerEncoder`` and mean-pooled
    over the real (non-padded) messages.

    Attribute names (``bert``, ``transformer_encoder``, ``pos_encoder``) are part
    of the saved ``state_dict`` keys.

    Args:
        hidden_dim: Transformer width; must equal BERT's hidden size because the
            BERT vectors are fed to the encoder directly.
        num_layers: Number of transformer encoder layers.
        num_heads: Attention heads per layer.
        dropout: Dropout used in the encoder layers and positional encoding.
    """

    def __init__(self, hidden_dim=768, num_layers=2, num_heads=8, dropout=0.1):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        layer_template = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads, dropout=dropout, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(layer_template, num_layers=num_layers)
        # Histories longer than max_len positions would not fit the table.
        self.pos_encoder = PositionalEncoding(hidden_dim, dropout=dropout, max_len=20)  # max history length assumed <= 20

    def forward(self, history_texts_batch):
        """Encode a batch of histories.

        Args:
            history_texts_batch: Sequence with one history per example; each
                history is a list of message strings. An empty history is
                encoded as the single string ``""``.

        Returns:
            Tensor ``(batch, hidden)``: per example, the mean transformer output
            over its real history positions.
        """
        device = next(self.parameters()).device

        per_example = []
        for texts in history_texts_batch:
            # Substitute [""] so every example has at least one valid position.
            batch_texts = texts if len(texts) > 0 else [""]
            encoded = self.tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True, return_attention_mask=True)
            encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
            with torch.no_grad():
                bert_out = self.bert(**encoded)
            per_example.append(bert_out.last_hidden_state[:, 0, :])

        lengths = [emb.size(0) for emb in per_example]
        padded = nn.utils.rnn.pad_sequence(per_example, batch_first=True)
        # True marks padded slots, as src_key_padding_mask expects.
        slots = torch.arange(padded.size(1), device=device).unsqueeze(0)
        pad_mask = slots >= torch.tensor(lengths, device=device).unsqueeze(1)

        hidden = self.transformer_encoder(self.pos_encoder(padded), src_key_padding_mask=pad_mask)
        return torch.stack([hidden[row, :n, :].mean(dim=0) for row, n in enumerate(lengths)], dim=0)
class FusionAttention(nn.Module):
    """Fuse a short stack of modality vectors into a single vector.

    The tokens first go through multi-head self-attention. The attended tokens
    are then pooled with a learned query (``pooling_query``), using unscaled
    dot-product scores and a softmax over the token axis.

    Args:
        input_dim: Token width (must be divisible by ``num_heads``).
        num_heads: Number of attention heads.
    """

    def __init__(self, input_dim, num_heads=2):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=input_dim, num_heads=num_heads, batch_first=True)
        self.pooling_query = nn.Parameter(torch.randn(1, 1, input_dim))

    def forward(self, features):
        """Map ``features`` of shape (batch, tokens, input_dim) to (batch, input_dim)."""
        attended = self.attn(features, features, features)[0]
        q = self.pooling_query.expand(attended.size(0), -1, -1)   # (batch, 1, dim)
        scores = torch.bmm(q, attended.transpose(1, 2))           # (batch, 1, tokens)
        weights = scores.softmax(dim=-1)
        return torch.bmm(weights, attended).squeeze(1)

class DeceptionClassifier(nn.Module):
    """Two-layer MLP head that emits one raw logit per fused vector.

    Layout: Linear(fused_dim -> 256), ReLU, Dropout(0.3), Linear(256 -> 1).
    The layers stay at ``self.classifier`` indices 0-3 so checkpoint keys
    match. Callers in this notebook apply ``torch.sigmoid`` to the output.

    Args:
        fused_dim: Width of the incoming fused vector.
    """

    def __init__(self, fused_dim):
        super().__init__()
        hidden_units = 256
        self.classifier = nn.Sequential(
            nn.Linear(fused_dim, hidden_units),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_units, 1),
        )

    def forward(self, fused_vector):
        """Return logits of shape ``(..., 1)``."""
        return self.classifier(fused_vector)

class MultiModalDeceptionModel(nn.Module):
    """Binary classifier fusing message text, game state, speakers and history.

    Three ``fusion_dim``-wide vectors are built per example and fused with
    ``FusionAttention`` before the ``DeceptionClassifier`` head:

    * text: BERT ``[CLS]`` embedding of the current message (no ``no_grad`` on
      this path, so gradients can reach ``self.bert``);
    * game: ``GameStateEncoder`` output concatenated with sender and receiver
      embeddings, then projected to ``fusion_dim`` by two linear layers;
    * history: ``HistoryEncoderTransformer`` summary of previous messages.

    ``fusion_dim`` must equal BERT's hidden size, since the text vector is
    stacked with the other two. Submodule attribute names form the saved
    ``state_dict`` keys; renaming them breaks loading existing checkpoints.

    Args:
        season_vocab_size: Width of the season one-hot block in the game state.
        year_vocab_size: Width of the year one-hot block in the game state.
        game_state_out_dim: Output width of ``GameStateEncoder``.
        fusion_dim: Width of every fused modality vector.
        proj_hidden_dim: Hidden width between ``sender_receiver_proj`` and
            ``proj_to_text_dim``. ``None`` (default) keeps the original choice,
            ``tokenizer.model_max_length`` (falling back to ``fusion_dim``),
            which is what checkpoints produced by this code were built with.
    """

    def __init__(self, season_vocab_size=3, year_vocab_size=5,
                 game_state_out_dim=32, fusion_dim=768, proj_hidden_dim=None):
        super(MultiModalDeceptionModel, self).__init__()
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.game_encoder = GameStateEncoder(season_vocab_size, year_vocab_size, out_dim=game_state_out_dim)

        speaker_emb_dim = 16
        # 7 = number of countries in DeceptionDataset's default country_to_idx.
        self.sender_embedding = nn.Embedding(num_embeddings=7, embedding_dim=speaker_emb_dim)
        self.receiver_embedding = nn.Embedding(num_embeddings=7, embedding_dim=speaker_emb_dim)

        # Width of [game_state | sender_emb | receiver_emb]. This used to be a
        # hard-coded 64, which only matched game_state_out_dim=32.
        combined_dim = game_state_out_dim + 2 * speaker_emb_dim
        if proj_hidden_dim is None:
            # Historical default kept for checkpoint compatibility: the hidden
            # width reuses the tokenizer's model_max_length (presumably 512 for
            # bert-base-uncased — verify against the saved weights).
            proj_hidden_dim = getattr(self.tokenizer, 'model_max_length', fusion_dim)
        self.sender_receiver_proj = nn.Linear(combined_dim, proj_hidden_dim)
        self.proj_to_text_dim = nn.Linear(self.sender_receiver_proj.out_features, fusion_dim)

        self.history_encoder = HistoryEncoderTransformer(hidden_dim=fusion_dim)
        self.text_feature_dim = fusion_dim
        self.fusion_attention = FusionAttention(input_dim=self.text_feature_dim, num_heads=2)

        self.classifier = DeceptionClassifier(fused_dim=self.text_feature_dim)

    def forward(self, current_message, game_state_features, history_texts, sender_ids, receiver_ids):
        """Return raw logits of shape ``(batch, 1)``; apply sigmoid for probabilities.

        Args:
            current_message: List of message strings (tokenized here).
            game_state_features: Float tensor ``(batch, 2 + season + year vocab sizes)``.
            history_texts: One list of previous message strings per example.
            sender_ids: Long tensor ``(batch,)`` of country indices in ``[0, 7)``.
            receiver_ids: Long tensor ``(batch,)`` of country indices in ``[0, 7)``.
        """
        device = next(self.parameters()).device
        inputs = self.tokenizer(current_message, return_tensors="pt", truncation=True, padding=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        text_vector = self.bert(**inputs).last_hidden_state[:, 0, :]  # [CLS] token

        game_vector_raw = self.game_encoder(game_state_features)
        sender_emb = self.sender_embedding(sender_ids)
        receiver_emb = self.receiver_embedding(receiver_ids)
        combined_game = torch.cat([game_vector_raw, sender_emb, receiver_emb], dim=1)
        game_vector = self.proj_to_text_dim(self.sender_receiver_proj(combined_game))

        history_vector = self.history_encoder(history_texts)
        # (batch, 3, fusion_dim): one "token" per modality for the fusion attention.
        fusion_input = torch.stack([text_vector, game_vector, history_vector], dim=1)

        fused_vector = self.fusion_attention(fusion_input)
        return self.classifier(fused_vector)


class DeceptionDataset(Dataset):
    """Flatten a JSONL file of games into one sample per message.

    Each non-blank line must be a game object with parallel per-message lists
    ``messages``, ``seasons``, ``years``, ``history``, ``speakers``,
    ``receivers`` and ``sender_labels``. ``game_score`` and
    ``game_score_delta`` are optional: missing or non-numeric entries become 0.0.

    Each sample is a dict with:

    * ``current_message``: the message as stored;
    * ``game_state_features``: float tensor
      ``[game_score, score_delta] + season one-hot + year one-hot``;
    * ``history``: the entry stored at ``history[i]``;
    * ``sender`` / ``receiver``: long scalar tensors (unknown names map to 0);
    * ``label``: float scalar tensor, 0.0 when ``sender_labels[i]`` is ``True``
      or ``"true"``, otherwise 1.0.

    Args:
        jsonl_file: Path to the JSONL file.
        season_to_idx: Season name -> one-hot index (default Spring/Fall/Winter).
        year_buckets: Ascending upper bounds used by ``bucket_year``.
        country_to_idx: Lower-case country name -> index.
    """

    def __init__(self, jsonl_file, season_to_idx=None, year_buckets=None, country_to_idx=None):
        self.samples = []
        self.season_to_idx = season_to_idx or {"Spring": 0, "Fall": 1, "Winter": 2}
        self.year_buckets = year_buckets or [1901, 1906, 1911, 1916, 1921]
        self.country_to_idx = country_to_idx or {"russia":0, "turkey":1, "england":2, "france":3, "germany":4, "italy":5, "austria":6}

        with open(jsonl_file, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                game = json.loads(line)
                for i in range(len(game['messages'])):
                    self.samples.append(self._build_sample(game, i))

    @staticmethod
    def _score_at(game, key, i):
        """Return ``float(game[key][i])``, or 0.0 if the field is missing or not numeric."""
        # Explicit exceptions instead of a bare ``except:``, so interrupts and
        # genuine programming errors are no longer silently swallowed.
        try:
            return float(game[key][i])
        except (KeyError, IndexError, TypeError, ValueError, OverflowError):
            return 0.0

    def _build_sample(self, game, i):
        """Build the sample dict for message ``i`` of ``game``."""
        season_vec = [0] * len(self.season_to_idx)
        season = game['seasons'][i]
        if season in self.season_to_idx:  # unknown seasons leave an all-zero block
            season_vec[self.season_to_idx[season]] = 1

        year_vec = [0] * len(self.year_buckets)
        year_vec[self.bucket_year(int(game['years'][i]))] = 1

        game_state = ([self._score_at(game, 'game_score', i),
                       self._score_at(game, 'game_score_delta', i)]
                      + season_vec + year_vec)

        label_raw = game['sender_labels'][i]
        # True / "true" -> 0; anything else -> 1 (presumably the deceptive class).
        label = 0 if label_raw is True or label_raw == "true" else 1

        return {
            'current_message': game['messages'][i],
            'game_state_features': torch.tensor(game_state, dtype=torch.float),
            'history': game['history'][i],
            # Unknown country names silently fall back to index 0.
            'sender': torch.tensor(self.country_to_idx.get(game['speakers'][i].lower(), 0), dtype=torch.long),
            'receiver': torch.tensor(self.country_to_idx.get(game['receivers'][i].lower(), 0), dtype=torch.long),
            'label': torch.tensor(label, dtype=torch.float),
        }

    def bucket_year(self, year):
        """Return the index of the first bound in ``year_buckets`` that ``year`` is below.

        Years at or above the last bound are clamped into the last bucket. With
        the default bounds, bucket 0 only receives years < 1901, and bucket 4
        covers every year >= 1916.

        NOTE(review): this looks off by one relative to 5-year bins starting at
        1901. It is left unchanged because the saved checkpoint was presumably
        trained on this encoding.
        """
        for idx, bound in enumerate(self.year_buckets):
            if year < bound:
                return idx
        return len(self.year_buckets) - 1

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

def collate_fn(batch):
    """Turn a list of DeceptionDataset samples into model-ready batch pieces.

    Returns a 6-tuple, in order:
    message strings (list), stacked game-state tensor, histories (list, left
    unstacked because their lengths vary), stacked sender ids, stacked
    receiver ids, stacked float labels.
    """
    def stacked(key):
        return torch.stack([item[key] for item in batch])

    messages = [item['current_message'] for item in batch]
    histories = [item['history'] for item in batch]
    return (messages, stacked('game_state_features'), histories,
            stacked('sender'), stacked('receiver'), stacked('label'))

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:

# Load the best checkpoint and score the test split. Later cells reuse
# `model`, `test_loader`, `all_preds` and `all_labels` defined here.
from sklearn.metrics import confusion_matrix

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model = MultiModalDeceptionModel(season_vocab_size=3, year_vocab_size=5,
                                 game_state_out_dim=32, fusion_dim=768)
model.to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# torch.load unpickles its input: only load trusted checkpoint files.
model.load_state_dict(torch.load("best_model_attention2.pt", map_location=device))
model.eval()
test_dataset = DeceptionDataset("test_with_history_10.jsonl")

test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)

all_preds = []
all_labels = []

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Test Evaluation"):
        current_messages, game_states, histories, sender_ids, receiver_ids, labels = batch

        # Only the model inputs need to be on the device; labels stay on the CPU.
        game_states = game_states.to(device)
        sender_ids = sender_ids.to(device)
        receiver_ids = receiver_ids.to(device)

        logits = model(current_messages, game_states, histories, sender_ids, receiver_ids)
        prob = torch.sigmoid(logits.view(-1))
        preds = (prob > 0.5).long()

        all_preds.extend(preds.tolist())
        all_labels.extend(labels.long().tolist())

# Pin both classes so the matrix is always 2x2, even if one class is absent.
cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
print("Final Confusion Matrix:")

print(cm)


cuda


Test Evaluation: 100%|██████████| 172/172 [01:40<00:00,  1.71it/s]

Final Confusion Matrix:
[[2328  173]
 [ 189   51]]





In [5]:
import torch
import torch.nn as nn

# Rebuild the model, load the best checkpoint and print its module tree.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model = MultiModalDeceptionModel(season_vocab_size=3, year_vocab_size=5,
                                 game_state_out_dim=32, fusion_dim=768)
model.to(device)
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
# torch.load unpickles its input: only load trusted checkpoint files.
model.load_state_dict(torch.load("best_model_attention2.pt", map_location=device))

print(model)

cuda
MultiModalDeceptionModel(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Headline test metrics from the predictions collected by the evaluation cell.
accuracy = accuracy_score(all_labels, all_preds)
# Macro averages weight both classes equally, which matters here because
# class 1 is much rarer than class 0.
precision_macro, recall_macro, f1_macro = (
    metric(all_labels, all_preds, average="macro")
    for metric in (precision_score, recall_score, f1_score)
)
f1_per_class = f1_score(all_labels, all_preds, average=None, labels=[0, 1])

print(f"Accuracy          : {accuracy:.4f}")
print(f"Precision (avg)   : {precision_macro:.4f}")
print(f"Recall    (avg)   : {recall_macro:.4f}")
print(f"F1 Score  (avg)   : {f1_macro:.4f}")
print(f"F1 Score (class 0): {f1_per_class[0]:.4f}")
print(f"F1 Score (class 1): {f1_per_class[1]:.4f}")


Accuracy       : 0.8679
Precision (avg): 0.5763
Recall    (avg): 0.5717
F1 Score  (avg): 0.5738
F1 Score (class 0): 0.9279
F1 Score (class 1): 0.2198


In [None]:
import torch
from tqdm import tqdm
from sklearn.metrics import classification_report
# Re-score the test split and keep every misclassified example for inspection.
default_speaker_mapping = {
    "russia": 0, "turkey": 1, "england": 2,
    "france": 3, "germany": 4, "italy": 5, "austria": 6
}
# Reverse lookup: country index -> lower-case country name.
inv_speaker_mapping = dict(zip(default_speaker_mapping.values(), default_speaker_mapping.keys()))

misclassified_class_0 = []  # true label 0, predicted 1
misclassified_class_1 = []  # true label 1, predicted 0
all_true = []
all_preds = []
model.eval()

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Evaluating and Collecting Misclassified Examples"):
        current_messages, game_states, histories, sender_ids, receiver_ids, labels = batch

        game_states = game_states.to(device)
        sender_ids = sender_ids.to(device)
        receiver_ids = receiver_ids.to(device)
        labels = labels.to(device)

        logits = model(current_messages, game_states, histories, sender_ids, receiver_ids)
        preds = (torch.sigmoid(logits.view(-1)) > 0.5).long()

        # One host transfer per tensor rather than one .item() call per element.
        true_batch = labels.cpu().tolist()
        pred_batch = preds.cpu().tolist()
        sender_batch = sender_ids.cpu().tolist()
        receiver_batch = receiver_ids.cpu().tolist()
        all_true.extend(true_batch)
        all_preds.extend(pred_batch)

        for j, (true_value, pred_value) in enumerate(zip(true_batch, pred_batch)):
            if pred_value == true_value:
                continue
            record = {
                'current_message': current_messages[j],
                'history': histories[j],
                'true_label': true_value,
                'predicted_label': pred_value,
                'sender': inv_speaker_mapping[sender_batch[j]],
                'receiver': inv_speaker_mapping[receiver_batch[j]],
            }
            target = misclassified_class_0 if true_value == 0 else misclassified_class_1
            target.append(record)

print("Classification Report:\n")
report = classification_report(all_true, all_preds, target_names=["Class 0", "Class 1"])
print(report)

# (caption, record key) pairs printed for each example, in display order.
EXAMPLE_FIELDS = (
    ("Message        :", 'current_message'),
    ("History        :", 'history'),
    ("Sender         :", 'sender'),
    ("Receiver       :", 'receiver'),
    ("True Label     :", 'true_label'),
    ("Predicted Label:", 'predicted_label'),
)


def show_examples(samples, limit=5):
    """Print up to `limit` misclassified records, one field per line."""
    for sample in samples[:limit]:
        for caption, key in EXAMPLE_FIELDS:
            print(caption, sample[key])
        print("-" * 50)


print("Misclassified Examples for Class 0 (true label 0 predicted as 1):")
show_examples(misclassified_class_0)

print("\nMisclassified Examples for Class 1 (true label 1 predicted as 0):")
show_examples(misclassified_class_1)


Evaluating and Collecting Misclassified Examples:   0%|          | 0/172 [00:00<?, ?it/s]

Evaluating and Collecting Misclassified Examples: 100%|██████████| 172/172 [01:46<00:00,  1.61it/s]

Classification Report:

              precision    recall  f1-score   support

     Class 0       0.92      0.93      0.93      2501
     Class 1       0.23      0.21      0.22       240

    accuracy                           0.87      2741
   macro avg       0.58      0.57      0.57      2741
weighted avg       0.86      0.87      0.87      2741

Misclassified Examples for Class 0 (true label 0 predicted as 1):
Message        : also weirdly my phone doesnt seem to be updating me with latest messages the way it should so i end up seeing these a lot later than i should
History        : ['i mean if im free to go another direction ill take it in a heartbeat and i doubt youll be very upset at my opening moves', 'his dots are closer to you', 'so now his dots are a lot closer to you than mine are im guessing youre ordering into trieste anyway to protect serbia while you take greece with bulgaria and aegean so the question is what is ion doing', 'oh come now its much more amusing to support 




In [None]:
import torch
from collections import defaultdict
from tqdm import tqdm
# Per-speaker breakdown of test-set errors and label distributions.
default_speaker_mapping = {
    "russia": 0, "turkey": 1, "england": 2,
    "france": 3, "germany": 4, "italy": 5, "austria": 6
}
inv_speaker_mapping = {v: k for k, v in default_speaker_mapping.items()}


speaker_total = defaultdict(int)                               # sender id -> #samples
speaker_misclassified = defaultdict(int)                       # sender id -> #errors
speaker_label_counts = defaultdict(lambda: defaultdict(int))   # sender id -> true label -> count
speaker_pred_counts = defaultdict(lambda: defaultdict(int))    # sender id -> predicted label -> count
model.eval()
with torch.no_grad():
    for batch in tqdm(test_loader, desc="Analyzing Speaker Distribution"):
        current_messages, game_states, histories, sender_ids, receiver_ids, labels = batch

        logits = model(current_messages, game_states.to(device), histories,
                       sender_ids.to(device), receiver_ids.to(device))
        preds = (torch.sigmoid(logits.view(-1)) > 0.5).long()

        # Labels arrive as floats (0.0 / 1.0). Casting to int makes the true and
        # predicted distributions share keys, so both print as "Label 0" / "Label 1".
        for sender, true_label, pred_label in zip(sender_ids.tolist(),
                                                  labels.long().tolist(),
                                                  preds.cpu().tolist()):
            speaker_total[sender] += 1
            speaker_label_counts[sender][true_label] += 1
            speaker_pred_counts[sender][pred_label] += 1

            if pred_label != true_label:
                speaker_misclassified[sender] += 1


print("Speaker Distribution, True Label Distribution, Predicted Label Distribution, and Misclassification Rates:")
for sid in sorted(inv_speaker_mapping.keys()):
    total = speaker_total[sid]
    misclassified = speaker_misclassified[sid]
    error_rate = misclassified / total if total > 0 else 0.0
    true_counts = speaker_label_counts[sid]
    pred_counts = speaker_pred_counts[sid]

    print(f"Speaker: {inv_speaker_mapping[sid].capitalize()}")
    print(f"  Total Samples       : {total}")
    print(f"  Misclassified Count : {misclassified} (Error Rate: {error_rate:.4f})")

    print("  True Label Distribution:")
    for label, count in sorted(true_counts.items()):
        print(f"    Label {label} : {count}")

    print("  Predicted Label Distribution:")
    for label, count in sorted(pred_counts.items()):
        print(f"    Label {label} : {count}")
    print("-" * 50)


Analyzing Speaker Distribution: 100%|██████████| 172/172 [01:40<00:00,  1.71it/s]

Speaker Distribution, True Label Distribution, Predicted Label Distribution, and Misclassification Rates:
Speaker: Russia
  Total Samples       : 852
  Misclassified Count : 109 (Error Rate: 0.1279)
  True Label Distribution:
    Label 0.0 : 790
    Label 1.0 : 62
  Predicted Label Distribution:
    Label 0 : 781
    Label 1 : 71
--------------------------------------------------
Speaker: Turkey
  Total Samples       : 441
  Misclassified Count : 57 (Error Rate: 0.1293)
  True Label Distribution:
    Label 0.0 : 401
    Label 1.0 : 40
  Predicted Label Distribution:
    Label 0 : 418
    Label 1 : 23
--------------------------------------------------
Speaker: England
  Total Samples       : 307
  Misclassified Count : 22 (Error Rate: 0.0717)
  True Label Distribution:
    Label 0.0 : 291
    Label 1.0 : 16
  Predicted Label Distribution:
    Label 0 : 299
    Label 1 : 8
--------------------------------------------------
Speaker: France
  Total Samples       : 804
  Misclassified Count


