# Enhancing Movie Recommendations with Knowledge Graph Embeddings and Neural Collaborative Filtering

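This notebook builds the Knowledge Graph stage of the pipeline from IMDb TSV files stored in Google Drive and saves it as integer-ID triples (`triples.csv`) plus entity and relation ID maps (`entity_map.csv`, `relation_map.csv`).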

**Entities:**
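- Movies (tconst) — every title in title.basics; no `titleType` filter is applied, so shorts, TV titles, etc. are included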
- Persons (nconst)
- Genres (from title.basics genres column)

**Relations:**
- (movie) --HAS_GENRE--> (genre)
- (person) --DIRECTED--> (movie)
- (person) --WROTE--> (movie)
- (person) --ACTED_IN--> (movie) — from title.principals rows whose category is `actor` or `actress`

In [None]:
# Install required packages (if not already installed)
!pip install pandas tqdm -q


In [None]:
# Configuration
import os
import random
from collections import defaultdict

import numpy as np
import pandas as pd
# tqdm.auto selects the notebook widget front-end under Jupyter/Colab and
# falls back to the plain-text progress bar elsewhere.
from tqdm.auto import tqdm

# Set random seeds for reproducibility. No cell in this notebook draws random
# numbers yet; seeding up front keeps any sampling added later reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

# Paths (Google Drive is mounted at /content/drive in the next cell)
DRIVE_ROOT = '/content/drive/MyDrive'
INPUT_DIR = f'{DRIVE_ROOT}/imdb_data'  # Folder containing TSV files
OUTPUT_DIR = f'{DRIVE_ROOT}/kg_output'  # Output folder for triples and maps

# Dry-run mode: set to None for full processing, or a positive integer to read
# only the first N rows of each file (0 is treated the same as None).
LIMIT_ROWS = None  # Change to 10000 for quick testing

# IMDb file names, relative to INPUT_DIR
FILES = {
    'basics': 'title.basics.tsv',
    'crew': 'title.crew.tsv',
    'principals': 'title.principals.tsv',
    'ratings': 'title.ratings.tsv'  # Not used for triples but can be read
}

print("Configuration loaded:")
print(f"  Input directory: {INPUT_DIR}")
print(f"  Output directory: {OUTPUT_DIR}")
print(f"  Limit rows: {LIMIT_ROWS if LIMIT_ROWS else 'None (full processing)'}")


In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Fail fast if a TSV read by Steps 1-3 is missing, rather than erroring later
# inside a processing step. ('ratings' is not read by this pipeline.)
missing_inputs = [
    os.path.join(INPUT_DIR, FILES[key])
    for key in ('basics', 'crew', 'principals')
    if not os.path.isfile(os.path.join(INPUT_DIR, FILES[key]))
]
if missing_inputs:
    raise FileNotFoundError(
        "Missing IMDb input file(s): " + ", ".join(missing_inputs)
        + f". Place them in {INPUT_DIR} or update INPUT_DIR/FILES in the configuration cell."
    )

# Create output directory if it doesn't exist
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Output directory ready: {OUTPUT_DIR}")


In [None]:
# Helper functions

def safe_split(value, sep=','):
    """Split a delimited field into a list of stripped, non-empty tokens.

    A missing field (NaN/None), the IMDb ``\\N`` marker, or an empty string
    yields ``[]``. Individual tokens that are blank or equal to ``\\N`` are
    discarded.
    """
    if pd.isna(value) or value in ('\\N', ''):
        return []
    stripped = (piece.strip() for piece in str(value).split(sep))
    return [token for token in stripped if token and token != '\\N']

def read_tsv_chunked(filepath, limit=None, chunksize=100000):
    """Read a tab-separated file either as one row-limited frame or in chunks.

    Args:
        filepath: Path to a tab-separated file with a header row.
        limit: If truthy, read only the first ``limit`` data rows (dry-run
            mode). ``None`` or ``0`` reads the whole file.
        chunksize: Rows per chunk when reading the whole file.

    Returns:
        A one-element list holding a DataFrame when ``limit`` is set;
        otherwise a pandas ``TextFileReader`` that yields DataFrame chunks.
        Both can be consumed with a ``for`` loop. Every column is read as
        text, and ``\\N`` markers are kept as the literal string ``'\\N'``.
    """
    import csv

    read_options = {
        'sep': '\t',
        # Treat fields as unquoted. With the default QUOTE_MINIMAL, a value
        # that starts with '"' (IMDb titles can) opens a "quoted field" that
        # swallows tabs/newlines up to the next quote, corrupting rows.
        'quoting': csv.QUOTE_NONE,
        # IDs and \N markers are text; skip per-chunk dtype inference.
        'dtype': str,
        'low_memory': False,
    }
    if limit:
        # Dry-run: only the first N rows, wrapped as a single chunk.
        return [pd.read_csv(filepath, nrows=limit, **read_options)]
    # Full processing: lazily yield chunks to bound memory use.
    return pd.read_csv(filepath, chunksize=chunksize, **read_options)

def get_file_path(filename):
    """Return *filename* joined onto the configured INPUT_DIR."""
    input_root = INPUT_DIR
    return os.path.join(input_root, filename)


print("Helper functions defined")


In [None]:
# Step 1: Extract entities and build triples from title.basics.tsv
# Entities: movies (tconst), genres
# Relations: (movie) --HAS_GENRE--> (genre)
#
# NOTE: this cell (re)creates `triples`; Steps 2 and 3 append to it, so after
# re-running this cell, re-run Steps 2 and 3 as well.
# NOTE(review): no `titleType` filter is applied, so every title in
# title.basics becomes a "movie" entity — filter here if only films are wanted.

print("=" * 60)
print("Step 1: Processing title.basics.tsv")
print("=" * 60)

triples = []
movies = set()
genres = set()
has_genre_count = 0

basics_path = get_file_path(FILES['basics'])
print(f"Reading: {basics_path}")

chunks = read_tsv_chunked(basics_path, limit=LIMIT_ROWS)

for chunk_idx, chunk in enumerate(chunks):
    print(f"Processing chunk {chunk_idx + 1}... (rows: {len(chunk)})")

    # Filter out rows with missing tconst
    chunk = chunk[chunk['tconst'].notna() & (chunk['tconst'] != '\\N')]

    # Iterate plain column values instead of DataFrame.iterrows(), which builds
    # a pandas Series for every row and dominates runtime on the full file.
    # A missing `genres` column behaves like an empty field on every row.
    movie_ids = chunk['tconst'].astype(str).str.strip()
    rows = zip(movie_ids, chunk.get('genres', [''] * len(chunk)))

    for movie_id, genre_field in tqdm(rows, total=len(chunk), desc=f"Chunk {chunk_idx + 1}"):
        if not movie_id or movie_id == '\\N':
            continue

        movies.add(movie_id)

        # Extract genres (safe_split never returns empty tokens)
        for genre in safe_split(genre_field):
            genres.add(genre)
            triples.append(('HAS_GENRE', movie_id, genre))
            has_genre_count += 1

print(f"\nStep 1 Complete:")
print(f"  Movies found: {len(movies)}")
print(f"  Genres found: {len(genres)}")
print(f"  HAS_GENRE triples: {has_genre_count}")


In [None]:
# Step 2: Extract relations from title.crew.tsv
# Relations: (person) --DIRECTED--> (movie), (person) --WROTE--> (movie)
#
# NOTE: this cell appends to `triples` (created in Step 1) and re-creates
# `persons` (extended in Step 3). Re-running it on its own duplicates the
# DIRECTED/WROTE triples and drops actors, so re-run Steps 1 -> 3 in order.

print("=" * 60)
print("Step 2: Processing title.crew.tsv")
print("=" * 60)

persons = set()
crew_path = get_file_path(FILES['crew'])
print(f"Reading: {crew_path}")

chunks = read_tsv_chunked(crew_path, limit=LIMIT_ROWS)
directed_count = 0
wrote_count = 0

for chunk_idx, chunk in enumerate(chunks):
    print(f"Processing chunk {chunk_idx + 1}... (rows: {len(chunk)})")

    # Filter out rows with missing tconst
    chunk = chunk[chunk['tconst'].notna() & (chunk['tconst'] != '\\N')]

    # Iterate plain column values instead of DataFrame.iterrows(), which builds
    # a pandas Series for every row and dominates runtime on the full file.
    # A missing column behaves like an empty field on every row.
    blanks = [''] * len(chunk)
    movie_ids = chunk['tconst'].astype(str).str.strip()
    rows = zip(movie_ids, chunk.get('directors', blanks), chunk.get('writers', blanks))

    for movie_id, director_field, writer_field in tqdm(rows, total=len(chunk), desc=f"Chunk {chunk_idx + 1}"):
        if not movie_id or movie_id == '\\N':
            continue

        # Extract directors (safe_split never returns empty tokens)
        for director_id in safe_split(director_field):
            persons.add(director_id)
            triples.append(('DIRECTED', director_id, movie_id))
            directed_count += 1

        # Extract writers
        for writer_id in safe_split(writer_field):
            persons.add(writer_id)
            triples.append(('WROTE', writer_id, movie_id))
            wrote_count += 1

print(f"\nStep 2 Complete:")
print(f"  Persons found so far: {len(persons)}")
print(f"  DIRECTED triples: {directed_count}")
print(f"  WROTE triples: {wrote_count}")


In [None]:
# Step 3: Extract relations from title.principals.tsv
# Relations: (person) --ACTED_IN--> (movie) where category in {"actor", "actress"}
#
# NOTE: this cell appends to `triples` and `persons` from Steps 1-2; re-running
# it on its own duplicates the ACTED_IN triples, so re-run Steps 1 -> 3 in order.

print("=" * 60)
print("Step 3: Processing title.principals.tsv")
print("=" * 60)

principals_path = get_file_path(FILES['principals'])
print(f"Reading: {principals_path}")

chunks = read_tsv_chunked(principals_path, limit=LIMIT_ROWS)
acted_count = 0

# Valid acting categories (compared case-insensitively below)
ACTING_CATEGORIES = {'actor', 'actress'}

for chunk_idx, chunk in enumerate(chunks):
    print(f"Processing chunk {chunk_idx + 1}... (rows: {len(chunk)})")

    # Keep rows that have both IDs and a category
    chunk = chunk[
        chunk['tconst'].notna() &
        (chunk['tconst'] != '\\N') &
        chunk['nconst'].notna() &
        (chunk['nconst'] != '\\N') &
        chunk['category'].notna()
    ]

    # Filter by category
    chunk = chunk[chunk['category'].str.lower().isin(ACTING_CATEGORIES)]

    # Iterate plain column values instead of DataFrame.iterrows(), which builds
    # a pandas Series for every row and dominates runtime on the full file.
    movie_ids = chunk['tconst'].astype(str).str.strip()
    person_ids = chunk['nconst'].astype(str).str.strip()

    for movie_id, person_id in tqdm(zip(movie_ids, person_ids), total=len(chunk), desc=f"Chunk {chunk_idx + 1}"):
        # Re-check after stripping whitespace
        if movie_id and person_id and movie_id != '\\N' and person_id != '\\N':
            persons.add(person_id)
            triples.append(('ACTED_IN', person_id, movie_id))
            acted_count += 1

print(f"\nStep 3 Complete:")
print(f"  Total persons: {len(persons)}")
print(f"  ACTED_IN triples: {acted_count}")


In [None]:
# Step 4: Create entity and relation mappings
# Give every entity / relation name a stable integer ID: names are numbered in
# sorted order starting at 1, so ID 0 is never assigned (free for padding).

print("=" * 60)
print("Step 4: Creating entity and relation mappings")
print("=" * 60)

# Every entity gathered in Steps 1-3
all_entities = movies | persons | genres

# Entity name -> integer ID (1..N, in sorted-name order)
entity_map = dict(zip(sorted(all_entities), range(1, len(all_entities) + 1)))

# The four relation types emitted by Steps 1-3, numbered the same way
relations = {'HAS_GENRE', 'DIRECTED', 'WROTE', 'ACTED_IN'}
relation_map = dict(zip(sorted(relations), range(1, len(relations) + 1)))

print(f"Entity mapping created: {len(entity_map)} entities")
print(f"Relation mapping created: {len(relation_map)} relations")
print(f"\nRelations: {sorted(relations)}")


In [None]:
# Step 5: Convert triples to integer IDs and create final dataset

print("=" * 60)
print("Step 5: Converting triples to integer IDs")
print("=" * 60)

# Keep only triples whose head, tail and relation all have an ID. Steps 2-3
# add the tconst of a crew/principals row to a triple without adding it to
# `movies`, so a title absent from the title.basics rows that were read has
# no entity ID; such triples are dropped here and reported below.
valid_triples = []
missing_entities = set()
missing_relations = set()

for rel, head, tail in tqdm(triples, desc="Converting triples"):
    head_id = entity_map.get(head)
    tail_id = entity_map.get(tail)
    rel_id = relation_map.get(rel)

    if head_id is None or tail_id is None or rel_id is None:
        # Record what was missing so the drop is reported, not silent.
        if head_id is None:
            missing_entities.add(head)
        if tail_id is None:
            missing_entities.add(tail)
        if rel_id is None:
            missing_relations.add(rel)
        continue

    valid_triples.append({
        'head': head_id,
        'relation': rel_id,
        'tail': tail_id,
        'head_str': head,
        'relation_str': rel,
        'tail_str': tail
    })

if missing_entities:
    print(f"Warning: {len(missing_entities)} entities not found in mapping; triples referencing them were skipped")
if missing_relations:
    print(f"Warning: triples with unknown relation(s) were skipped: {sorted(missing_relations)}")

print(f"Valid triples: {len(valid_triples)}")
print(f"Original triples: {len(triples)}")
print(f"Filtered out: {len(triples) - len(valid_triples)}")


In [None]:
# Step 6: Save outputs to Google Drive

print("=" * 60)
print("Step 6: Saving outputs")
print("=" * 60)

# Save triples.csv (with integer IDs).
# Explicit columns keep the schema when valid_triples is empty: pd.DataFrame([])
# has no columns, so the ['head', 'relation', 'tail'] selection would raise KeyError.
triples_df = pd.DataFrame(
    valid_triples,
    columns=['head', 'relation', 'tail', 'head_str', 'relation_str', 'tail_str'],
)
triples_df[['head', 'relation', 'tail']].to_csv(
    os.path.join(OUTPUT_DIR, 'triples.csv'),
    index=False
)
print(f"Saved: {OUTPUT_DIR}/triples.csv")
print(f"  Shape: {triples_df.shape}")

# Save entity_map.csv, one row per entity in ID order (IDs are unique, so the
# tuple sort never needs to compare names).
entity_map_df = pd.DataFrame(
    sorted((idx, entity) for entity, idx in entity_map.items()),
    columns=['entity_id', 'entity'],
)
entity_map_df.to_csv(
    os.path.join(OUTPUT_DIR, 'entity_map.csv'),
    index=False
)
print(f"Saved: {OUTPUT_DIR}/entity_map.csv")
print(f"  Shape: {entity_map_df.shape}")

# Save relation_map.csv, one row per relation in ID order
relation_map_df = pd.DataFrame(
    sorted((idx, rel) for rel, idx in relation_map.items()),
    columns=['relation_id', 'relation'],
)
relation_map_df.to_csv(
    os.path.join(OUTPUT_DIR, 'relation_map.csv'),
    index=False
)
print(f"Saved: {OUTPUT_DIR}/relation_map.csv")
print(f"  Shape: {relation_map_df.shape}")

print("\nAll outputs saved successfully!")


In [None]:
# Step 7: Print statistics

print("=" * 60)
print("KNOWLEDGE GRAPH STATISTICS")
print("=" * 60)

print("\nEntities:")
print(f"  Movies: {len(movies)}")
print(f"  Persons: {len(persons)}")
print(f"  Genres: {len(genres)}")
print(f"  Total entities: {len(all_entities)}")

# Count triples per relation in a single pass over valid_triples, instead of
# one full scan per relation type.
relation_counts = defaultdict(int)
for triple in valid_triples:
    relation_counts[triple['relation_str']] += 1

print("\nRelations:")
for rel in sorted(relations):
    print(f"  {rel}: {relation_counts[rel]:,} triples")

print(f"\nTotal triples: {len(valid_triples):,}")

print("\nOutput files:")
print(f"  {OUTPUT_DIR}/triples.csv")
print(f"  {OUTPUT_DIR}/entity_map.csv")
print(f"  {OUTPUT_DIR}/relation_map.csv")

print("\n" + "=" * 60)
print("Pipeline complete!")
print("=" * 60)
