In [1]:
from google.colab import drive

# Mount Google Drive so the dataset, checkpoints and logs under MyDrive are reachable from this VM.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install transformers datasets faiss-cpu torch-geometric networkx matplotlib

import os
import sys
import zipfile
import subprocess
import torch
import numpy as np
from datasets import load_dataset
from transformers import (
    TransfoXLTokenizer,
    TransfoXLLMHeadModel,
    Trainer,
    TrainingArguments,
    TransfoXLConfig,
    DataCollatorForLanguageModeling
)
import faiss
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.data import HeteroData

# NOTE(review): no transformers/datasets call in this notebook visibly reads this variable;
# remote code is normally enabled per call via trust_remote_code=True - confirm it is needed.
os.environ.update({'TRUST_REMOTE_CODE': '1'})



### Data Subset Creation

In [3]:
data_path = '/content/drive/MyDrive/GNN-LM_Project/Datasets/enwik8'
subset_size = 1000000  # number of characters (not bytes) kept for the experiment

if not os.path.exists(data_path):
    raise FileNotFoundError(f"The dataset file was not found at {data_path}")

print("Reading the enwik8 dataset...")
with open(data_path, 'r', encoding='utf-8') as f:
    # In text mode read(n) returns at most n characters, so only the prefix we actually use
    # is decoded instead of the whole ~100 MB file.
    subset_data = f.read(subset_size)
print(f"Subset size: {len(subset_data)} characters")

Reading the enwik8 dataset...
Subset size: 1000000 characters


### Preprocessing

In [4]:
def preprocess_enwik8(data, subset_size=None, out_dir='.'):
    """Split text into train/valid/test files in the Transformer-XL enwik8 character format.

    The last 10% of ``data`` (by character count) becomes the test split, the 10% before it
    the validation split, and the rest the training split. Each split is written twice:
    ``<name>.txt`` holds one space-separated decimal code point per character (newlines are
    kept as literal line breaks), and ``<name>.txt.raw`` holds the untouched text as UTF-8.

    Args:
        data: Text to split.
        subset_size: If given, only the first ``subset_size`` characters of ``data`` are used.
            Callers that already pass a truncated string see no change.
        out_dir: Directory the six output files are written to (default: current directory).

    Returns:
        None. Progress is printed to stdout.
    """
    if subset_size is not None:
        data = data[:subset_size]
    print(f'Length of enwik8_subset: {len(data)} characters')

    num_test_chars = len(data) // 10
    # Explicit cut points: with negative slices, num_test_chars == 0 (len(data) < 10) turned
    # data[:-0] into an empty training split and sent everything to the test split.
    train_end = len(data) - 2 * num_test_chars
    valid_end = len(data) - num_test_chars
    splits = {
        'train.txt': data[:train_end],
        'valid.txt': data[train_end:valid_end],
        'test.txt': data[valid_end:],
    }

    def save_split(filename, split_data):
        # Write one split as code-point text (<filename>) and raw UTF-8 (<filename>.raw).
        path = os.path.join(out_dir, filename)
        print(f'{filename} will have {len(split_data)} characters')
        print('- Tokenizing...')
        # '\n' is kept as-is so line structure survives; every other char becomes its code point.
        part_str = ' '.join(str(ord(c)) if c != '\n' else '\n' for c in split_data)
        print('- Writing...')
        with open(path, 'w', encoding='utf-8') as f:
            f.write(part_str)
        with open(path + '.raw', 'wb') as f:
            f.write(split_data.encode('utf-8'))

    for filename, split_data in splits.items():
        save_split(filename, split_data)

    print("Preprocessing completed: Data split into train.txt, valid.txt, and test.txt.")

# Writes train/valid/test .txt and .txt.raw files into the current working directory.
preprocess_enwik8(data=subset_data, subset_size=subset_size)

Length of enwik8_subset: 1000000 characters
train.txt will have 800000 characters
- Tokenizing...
- Writing...
valid.txt will have 100000 characters
- Tokenizing...
- Writing...
test.txt will have 100000 characters
- Tokenizing...
- Writing...
Preprocessing completed: Data split into train.txt, valid.txt, and test.txt.


### Load and Tokenize the Dataset

In [5]:
!pip install sacremoses



In [12]:
# NOTE: the deprecation warning below explains that TransfoXLTokenizer relies on pickle.load;
# only load this checkpoint from a source you trust.
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')

# The pretrained WT103 tokenizer has no padding token; register one so batches can be padded.
missing_pad = tokenizer.pad_token is None
if missing_pad:
    print("Adding [PAD] token to tokenizer...")
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

`TransfoXL` was deprecated due to security issues linked to `pickle.load` in `TransfoXLTokenizer`. See more details on this model's documentation page: `https://github.com/huggingface/transformers/blob/main/docs/source/en/model_doc/transfo-xl.md`.


Adding [PAD] token to tokenizer...


In [13]:
print("Loading the dataset...")
# One HuggingFace split per file written by preprocess_enwik8; the 'text' builder yields
# one example per line of each file.
split_files = {
    'train': 'train.txt',
    'validation': 'valid.txt',
    'test': 'test.txt',
}
dataset = load_dataset('text', data_files=split_files)

# Show the column names so the key used by tokenize_function ('text') can be checked.
print("Dataset columns:")
print(dataset['train'].column_names)

# Define a tokenization function
def tokenize_function(examples, tok=None):
    """Tokenize a batch of raw text lines (for ``Dataset.map(..., batched=True)``).

    Args:
        examples: Batch mapping with a ``'text'`` column holding a list of strings.
        tok: Tokenizer to apply; defaults to the notebook-level ``tokenizer``.

    Returns:
        The tokenizer's output for the batch (e.g. ``input_ids``), neither padded nor truncated.
    """
    if tok is None:
        tok = tokenizer
    # No truncation: group_texts later concatenates every line and re-chunks it into fixed
    # blocks, so cutting each line at 512 tokens would only throw text away.
    return tok(examples['text'], truncation=False, padding=False)

print("Tokenizing the dataset...")
# Map over every split in batches; the raw 'text' column is dropped so only token ids remain.
tokenized_datasets = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=['text'],
)

# Show the resulting features / row counts per split.
print("Tokenized dataset structure:")
print(tokenized_datasets)

Loading the dataset...
Dataset columns:
['text']
Tokenizing the dataset...


Map:   0%|          | 0/8400 [00:00<?, ? examples/s]

Map:   0%|          | 0/1842 [00:00<?, ? examples/s]

Map:   0%|          | 0/1421 [00:00<?, ? examples/s]

Tokenized dataset structure:
DatasetDict({
    train: Dataset({
        features: ['input_ids'],
        num_rows: 8400
    })
    validation: Dataset({
        features: ['input_ids'],
        num_rows: 1842
    })
    test: Dataset({
        features: ['input_ids'],
        num_rows: 1421
    })
})


### Group Texts into Blocks

In [14]:
block_size = 128  # tokens per training example after re-chunking


def group_texts(examples):
    """Concatenate a batch of tokenized lines and cut it into ``block_size``-token blocks.

    Intended for ``Dataset.map(group_texts, batched=True)``. Every column in ``examples`` is
    concatenated and chunked identically; the tail that does not fill a whole block is
    dropped (per batch).

    Args:
        examples: Mapping of column name -> list of token lists; must contain ``'input_ids'``.

    Returns:
        dict mapping each column name to a list of ``block_size``-long lists.
    """
    import itertools

    # chain.from_iterable is linear; sum(lists, []) re-copies the accumulator on every
    # addition and is quadratic in the batch length.
    concatenated = {
        k: list(itertools.chain.from_iterable(examples[k])) for k in examples.keys()
    }
    total_length = len(concatenated['input_ids'])
    total_length = (total_length // block_size) * block_size

    return {
        k: [tokens[i:i + block_size] for i in range(0, total_length, block_size)]
        for k, tokens in concatenated.items()
    }

print("Grouping texts into blocks...")
# Re-chunk every split into fixed-size blocks for causal LM training.
lm_datasets = tokenized_datasets.map(function=group_texts, batched=True)

Grouping texts into blocks...


Map:   0%|          | 0/8400 [00:00<?, ? examples/s]

Map:   0%|          | 0/1842 [00:00<?, ? examples/s]

Map:   0%|          | 0/1421 [00:00<?, ? examples/s]

### Prepare Data Collator

In [15]:
# mlm=False selects causal-LM collation: labels come from input_ids, no random masking.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)

### Model Architecture Updates and Training Setup

In [18]:
# Load the pretrained weights in float32. fp16=True below already gives mixed precision
# (fp16 autocast + GradScaler), and the scaler needs float32 master weights: with
# torch_dtype=torch.float16 the gradients are fp16 as well and GradScaler.unscale_ raises
# "Attempting to unscale FP16 gradients".
model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')

# Resize the token embeddings to cover the [PAD] token added to the tokenizer.
model.resize_token_embeddings(len(tokenizer))

# fp16 mixed precision requires a CUDA device in most transformers versions; fall back to
# full precision otherwise.
use_fp16 = torch.cuda.is_available()

# Define training arguments
# NOTE(review): newer transformers releases rename `evaluation_strategy` to `eval_strategy`;
# keep whichever name the installed version accepts.
training_args = TrainingArguments(
    output_dir='/content/drive/MyDrive/GNN-LM_Project/results',
    overwrite_output_dir=True,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    evaluation_strategy='steps',
    eval_steps=500,
    save_steps=1000,
    logging_steps=100,
    learning_rate=5e-5,
    weight_decay=0.01,
    logging_dir='/content/drive/MyDrive/GNN-LM_Project/logs',
    fp16=use_fp16,  # mixed-precision training (float32 weights, fp16 compute)
    fp16_full_eval=use_fp16,  # run evaluation entirely in fp16
)

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_datasets['train'],
    eval_dataset=lm_datasets['validation'],
    data_collator=data_collator,
)

# Train the model
# NOTE(review): an earlier run failed with `TypeError: type_as() missing 1 required
# positional arguments: "other"`; the cause is not identified here. TransfoXL is deprecated
# in transformers, so pinning a compatible transformers version may be required.
print("Starting training...")
trainer.train()

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Starting training...


TypeError: type_as() missing 1 required positional arguments: "other"

In [None]:
trained_model_dir = '/content/drive/MyDrive/GNN-LM_Project/trained_transformer-xl'
print("Saving the trained model...")
trainer.save_model(output_dir=trained_model_dir)

### Compute Token Representations

In [None]:
print("Loading the trained model for representation extraction...")
# Must be the directory passed to trainer.save_model(); the old relative path
# './trained_transformer-xl' pointed somewhere nothing was ever saved.
trained_model_dir = '/content/drive/MyDrive/GNN-LM_Project/trained_transformer-xl'
trained_model = TransfoXLLMHeadModel.from_pretrained(trained_model_dir)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
trained_model.to(device)
trained_model.eval();  # trailing ';' suppresses the (very long) model repr

def extract_representations(model, dataset, batch_size=32, device=None, flatten=True):
    """Run ``model.transformer`` over ``dataset`` and collect the last hidden states.

    Args:
        model: An LM whose ``.transformer`` sub-module returns an object with
            ``last_hidden_state`` (e.g. ``TransfoXLLMHeadModel``).
        dataset: A HuggingFace ``Dataset`` with an ``input_ids`` column of equal-length rows,
            or any map-style dataset yielding ``{'input_ids': tensor}`` items.
        batch_size: Rows per forward pass.
        device: Device batches are moved to; defaults to the device of the model's parameters.
        flatten: If True (default), return one row per token, shape
            ``(num_rows * seq_len, hidden)``, which is the layout the FAISS index, the
            sequence-id assignment and the graph construction index into. If False, keep
            ``(num_rows, seq_len, hidden)``.

    Returns:
        ``np.ndarray`` of float32 hidden states.

    Raises:
        ValueError: If ``dataset`` yields no batches.
    """
    if device is None:
        device = next(model.parameters()).device
    # Without a torch format the default collate turns each row's list of ints into a list of
    # per-position tensors instead of one (batch, seq) tensor. Keep only input_ids: the
    # tokenized data has no attention_mask column and TransfoXL's forward takes none.
    if hasattr(dataset, 'with_format'):
        dataset = dataset.with_format('torch', columns=['input_ids'])
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

    chunks = []
    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            hidden = model.transformer(input_ids, return_dict=True).last_hidden_state
            if flatten:
                hidden = hidden.reshape(-1, hidden.size(-1))
            # FAISS only accepts float32, so upcast in case the model runs in half precision.
            chunks.append(hidden.float().cpu().numpy())

    if not chunks:
        raise ValueError("extract_representations: dataset is empty")
    return np.concatenate(chunks, axis=0)

print("Extracting token representations for the training set...")
train_reps = extract_representations(
    model=trained_model,
    dataset=lm_datasets['train'],
    batch_size=32,
)
print(f'Train token representations shape: {train_reps.shape}')


### kNN Retrieval with FAISS

In [None]:
# Parameters for FAISS
k = 1024  # neighbours retrieved per query token
q = 128  # PQ code size in bits per vector (m sub-quantizers x 8 bits, see below)
cluster_count = 4096

# FAISS needs a C-contiguous float32 matrix with one row per token; this is a no-op when
# train_reps already has that layout.
train_reps = np.ascontiguousarray(train_reps, dtype=np.float32)
assert train_reps.ndim == 2, f"expected (num_tokens, dim) representations, got shape {train_reps.shape}"

d = train_reps.shape[1]

# Initialize FAISS index with Product Quantization (PQ)
nlist = cluster_count  # No of Voronoi cells
m = q // 8  # No of PQ sub-quantizers; each encodes its sub-vector with 8 bits (1 byte)
assert d % m == 0, f"IVFPQ needs the dimension ({d}) to be divisible by the number of sub-quantizers ({m})"
assert train_reps.shape[0] >= nlist, (
    f"need at least {nlist} training vectors to learn {nlist} clusters, got {train_reps.shape[0]}"
)

print("Initializing FAISS index...")
quantizer = faiss.IndexFlatL2(d)  # Using L2 distance
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, 8)  # 8 bits per sub-quantizer code

print("Training FAISS index...")
index.train(train_reps)
print("FAISS index trained.")

print("Adding vectors to FAISS index...")
index.add(train_reps)
print("Vectors added to FAISS index.")
# NOTE(review): index.nprobe stays at FAISS's default (one cell probed per query), so
# searches may return fewer than the requested hits (reported as -1); raise it for recall.
print("Vectors added to FAISS index.")

In [None]:
# Assign sequence IDs to each token
# Each block is considered a separate sequence
# Each block produced by group_texts is treated as its own sequence. Assuming the token
# representations are flattened block by block, block s occupies block_size consecutive
# rows, so every row receives the id of the block it came from.
num_sequences = len(lm_datasets['train'])
block_size_actual = block_size

print("Assigning sequence IDs to each token...")
sequence_ids = np.arange(num_sequences).repeat(block_size_actual)
print(f"Total tokens: {len(sequence_ids)}")
print(f"Shape of sequence_ids: {sequence_ids.shape}")

# Guard against a layout mismatch between the ids and the representation matrix.
assert train_reps.shape[0] == len(sequence_ids), "Sequence IDs length mismatch with token representations."

### kNN Search Without Same-Sequence Leakage

In [None]:
# Perform kNN search while avoiding data leakage
def knn_search_no_leakage(index, query_vectors, query_ids, k, index_ids=None, extra_candidates=10):
    """k-NN search that drops neighbours belonging to the query's own sequence.

    Args:
        index: A FAISS index (anything with ``search(x, n) -> (D, I)``).
        query_vectors: ``(num_queries, dim)`` float32 array.
        query_ids: Sequence id of each query row, length ``num_queries``.
        k: Number of neighbours to return per query.
        index_ids: Sequence id of every vector stored in ``index``; defaults to the
            notebook-level ``sequence_ids``.
        extra_candidates: Candidates fetched beyond ``k`` to replace filtered same-sequence
            hits. If every sequence contributes ``block_size`` vectors (the query itself
            included), only ``extra_candidates >= block_size`` rules out a shortfall caused
            by filtering.

    Returns:
        ``(D, I)``, each of shape ``(num_queries, k)``: distances and flat indices of the kept
        neighbours, in the order FAISS ranked them. Rows with fewer than ``k`` usable
        neighbours (too many same-sequence hits, or FAISS returning ``-1`` because it found
        fewer results than requested) are padded with ``I = -1`` and ``D = inf``, so the
        result is always a rectangular array.
    """
    if index_ids is None:
        index_ids = sequence_ids
    index_ids = np.asarray(index_ids)
    query_ids = np.asarray(query_ids)

    D, I = index.search(query_vectors, k + extra_candidates)
    D = np.asarray(D)
    I = np.asarray(I)

    # FAISS reports missing results as -1; those must not be used as indices into index_ids.
    found = I >= 0
    candidate_ids = np.full(I.shape, -1, dtype=index_ids.dtype if index_ids.size else np.int64)
    candidate_ids[found] = index_ids[I[found]]
    keep = found & (candidate_ids != query_ids[:, None])

    # A stable sort on "not kept" moves kept candidates to the front while preserving
    # FAISS's ranking among them; the first k columns are the answer.
    order = np.argsort(~keep, axis=1, kind='stable')[:, :k]
    kept = np.take_along_axis(keep, order, axis=1)
    filtered_I = np.where(kept, np.take_along_axis(I, order, axis=1), -1)
    filtered_D = np.where(kept, np.take_along_axis(D, order, axis=1), np.inf).astype(D.dtype, copy=False)
    return filtered_D, filtered_I

print("Performing kNN search while avoiding data leakage...")
D, I = knn_search_no_leakage(
    index=index,
    query_vectors=train_reps,
    query_ids=sequence_ids,
    k=k,
)
print(f'kNN search completed. Distance shape: {D.shape}, Indices shape: {I.shape}')

### Neighbor Context Retrieval with Window Size $l=1$, $r=1$

In [None]:
print("Preparing sequences for neighbor context retrieval...")
# With the default (python) dataset format the column is a list of plain lists, which have no
# .tolist(); np.asarray(...).tolist() accepts lists and CPU tensors alike and yields ints.
sequences = [np.asarray(seq).tolist() for seq in lm_datasets['train']['input_ids']]

def retrieve_neighbor_contexts(neighbor_indices, sequences, l=1, r=1, seq_len=None):
    """Collect a small token window around every retrieved neighbour.

    A flat token index ``idx`` is mapped to ``(sequence, position)`` as
    ``divmod(idx, seq_len)``. This matches how ``sequence_ids`` is built (each sequence
    contributes exactly ``seq_len`` consecutive ids).

    Args:
        neighbor_indices: Per-query iterables of flat token indices (e.g. the ``I`` matrix
            from the kNN search). Negative entries, FAISS's "no result" marker, are skipped.
        sequences: List of token-id lists, one per sequence.
        l: Tokens kept to the left of the neighbour (clipped at the sequence start).
        r: Tokens kept to the right of the neighbour (clipped at the sequence end).
        seq_len: Tokens per sequence; defaults to the notebook-level ``block_size``.

    Returns:
        One list per query, each holding the context windows ``left + [token] + right``.
    """
    if seq_len is None:
        seq_len = block_size

    neighbor_contexts = []
    for neighbors in neighbor_indices:
        contexts = []
        for neighbor_idx in neighbors:
            if neighbor_idx < 0:
                continue  # padding for a missing neighbour, not a real token
            seq_id, pos_in_seq = divmod(int(neighbor_idx), seq_len)
            seq = sequences[seq_id]
            left = seq[max(0, pos_in_seq - l): pos_in_seq]
            right = seq[pos_in_seq + 1: pos_in_seq + 1 + r]
            contexts.append(left + [seq[pos_in_seq]] + right)
        neighbor_contexts.append(contexts)
    return neighbor_contexts

print("Retrieving neighbor contexts with window size l=1, r=1...")
window_left, window_right = 1, 1
neighbor_contexts = retrieve_neighbor_contexts(I, sequences, l=window_left, r=window_right)
print("Neighbor contexts retrieved.")


### Graph Construction for Token Representations

In [None]:
print("Constructing the heterogeneous graph...")

graph = HeteroData()

# One node per training token for both node types. torch.from_numpy shares memory with the
# (float32, contiguous) numpy matrix, and both node types share a single tensor; previously
# each torch.tensor(...) call made another full copy. Clone before modifying either in place.
token_features = torch.from_numpy(np.ascontiguousarray(train_reps, dtype=np.float32))
graph['ao'].x = token_features  # Original tokens
graph['an'].x = token_features  # Neighbor tokens (same representations for simplicity)

# Edge types
# rintra: intra-context edges t -> t+1 between consecutive tokens of a sequence ('ao' -> 'ao')
# rinter: inter-context edges from each retrieved neighbour to its query token ('an' -> 'ao')

# Token i of sequence s is node s * block_size + i (row-major flattening of the blocks).
rintra_parts = [
    seq_id * block_size + np.arange(len(seq) - 1, dtype=np.int64)
    for seq_id, seq in enumerate(sequences)
]
rintra_src = np.concatenate(rintra_parts) if rintra_parts else np.empty(0, dtype=np.int64)
graph['ao', 'rintra', 'ao'].edge_index = torch.from_numpy(np.stack([rintra_src, rintra_src + 1]))

# Neighbour j of query token i gives an edge j -> i. Negative entries are "no result" padding
# from the kNN search and would be invalid node ids, so they are dropped.
neighbor_matrix = np.asarray(I, dtype=np.int64)
assert neighbor_matrix.ndim == 2, f"expected a (num_queries, k) neighbour matrix, got shape {neighbor_matrix.shape}"
rinter_src = neighbor_matrix.reshape(-1)
rinter_dst = np.repeat(np.arange(neighbor_matrix.shape[0], dtype=np.int64), neighbor_matrix.shape[1])
valid = rinter_src >= 0
graph['an', 'rinter', 'ao'].edge_index = torch.from_numpy(np.stack([rinter_src[valid], rinter_dst[valid]]))

print("Heterogeneous graph constructed.")

### Graph Visualization

In [None]:
import random

def visualize_graph(graph, num_nodes=1000, seed=None):
    """Draw a random sample of the heterogeneous token graph.

    Samples up to ``num_nodes`` 'ao' nodes and up to ``num_nodes`` 'an' nodes, keeps only
    the edges whose two endpoints were both sampled, and draws the result with a spring
    layout: 'ao' nodes blue, 'an' nodes green, rintra edges black, rinter edges red.

    Args:
        graph: ``HeteroData`` with node types 'ao'/'an' and edge types
            ('ao', 'rintra', 'ao') and ('an', 'rinter', 'ao').
        num_nodes: Maximum number of nodes sampled per node type.
        seed: Optional seed for node sampling and the layout, for reproducible figures.
            ``None`` keeps the previous behaviour (global ``random`` state, unseeded layout).
    """
    rng = random if seed is None else random.Random(seed)
    G = nx.Graph()

    num_ao = graph['ao'].x.size(0)
    num_an = graph['an'].x.size(0)
    sampled_nodes_ao = rng.sample(range(num_ao), min(num_nodes, num_ao))
    sampled_nodes_an = rng.sample(range(num_an), min(num_nodes, num_an))
    ao_ids = np.asarray(sampled_nodes_ao)
    an_ids = np.asarray(sampled_nodes_an)

    G.add_nodes_from((f'ao_{node}' for node in sampled_nodes_ao), type='ao')
    G.add_nodes_from((f'an_{node}' for node in sampled_nodes_an), type='an')

    # Keep only edges whose two endpoints were both sampled.
    rintra_edges = graph['ao', 'rintra', 'ao'].edge_index.numpy().T
    rintra_edges = rintra_edges[np.isin(rintra_edges[:, 0], ao_ids) & np.isin(rintra_edges[:, 1], ao_ids)]
    G.add_edges_from(((f'ao_{src}', f'ao_{dst}') for src, dst in rintra_edges), type='rintra')

    rinter_edges = graph['an', 'rinter', 'ao'].edge_index.numpy().T
    rinter_edges = rinter_edges[np.isin(rinter_edges[:, 0], an_ids) & np.isin(rinter_edges[:, 1], ao_ids)]
    G.add_edges_from(((f'an_{src}', f'ao_{dst}') for src, dst in rinter_edges), type='rinter')

    color_map = ['blue' if G.nodes[node]['type'] == 'ao' else 'green' for node in G]
    edge_colors = ['black' if attrs['type'] == 'rintra' else 'red' for _, _, attrs in G.edges(data=True)]

    fig, ax = plt.subplots(figsize=(12, 12))
    pos = nx.spring_layout(G, k=0.1, seed=seed)
    nx.draw_networkx_nodes(G, pos, node_size=20, node_color=color_map, alpha=0.7, ax=ax)
    nx.draw_networkx_edges(G, pos, edge_color=edge_colors, alpha=0.5, ax=ax)
    ax.set_title('Sampled Heterogeneous Graph')
    ax.axis('off')
    plt.show()

print("Visualizing a sampled portion of the heterogeneous graph...")
visualize_graph(graph=graph, num_nodes=500)