In [1]:
# general purpose python
import collections
import datetime
import glob
import importlib
import itertools
import json
import math
import os
import pickle
import random
import re
import shutil
import sys
import time
import warnings

# general purpose data science
import IPython
import ipywidgets as ipw
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly
import plotly.express as px
import plotly.graph_objects as go
from plotly.offline import download_plotlyjs
import pylab
import scipy
import seaborn as sns
import sklearn
from sklearn import *
import statsmodels as sm

# computer vision
import cv2
import imageio
import PIL
from PIL import *

# deep learning
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

# geospatial
import rasterio as rio
import rasterio.features

warnings.filterwarnings('ignore')

np.random.seed(1337)

mpl.rcParams['figure.dpi'] = 400

IPython.core.display.display(IPython.core.display.HTML("<style>.container { width:100% !important; }</style>"))

pd.options.display.max_colwidth = 32
pd.options.display.float_format = '{:,.6f}'.format
pd.options.display.expand_frame_repr = False

%matplotlib inline

sns.set(font_scale=1.3)
sns.set_style('whitegrid')
sns.set_palette(sns.color_palette('muted'))

plotly.offline.init_notebook_mode(connected=True)
plotly.io.templates.default = 'plotly_white'

In [2]:
data_dir = '../../data/movielens/'

## Load the data

In [3]:
movies = pd.read_csv(f'{data_dir}/movies_clean.csv')
ratings = pd.read_csv(f'{data_dir}/ratings_train.csv')

## Association rules

In [4]:
# parameters of the association rules
min_item_support = 0.01
min_pair_support = 0.001
min_confidence = 0.5

In [5]:
# create transaction list
# in this case a transaction is a user's every rating
transactions = set()
for _, rating_group in ratings.groupby(['user_id']):
    transactions.add(frozenset(rating_group.movie_id.tolist()))
print(len(transactions))

138493


In [6]:
# find frequent items and their frequencies
def find_frequent_items(transactions, min_support):
    item_counter = collections.Counter([item_id for transaction in transactions for item_id in transaction])
    min_count = len(transactions) * min_support
    frequent_items = {}
    for item_id, item_count in item_counter.items():
        if item_count >= min_count:
            frequent_items[item_id] = item_count
    return frequent_items
frequent_movies = find_frequent_items(transactions, min_item_support)
len(frequent_movies)

2233

In [7]:
# find frequent item pairs and their frequencies
def find_frequent_pairs(transactions, frequent_items, min_support):
    frequent_item_ids = set(frequent_items.keys())
    pair_counts = collections.defaultdict(int)
    min_count = len(transactions) * min_support
    frequent_pairs = {}
    i = 1
    for transaction in transactions:
        if i % 1000 == 0:
            if i % 10000 == 0:
                print('{0:0.1f}%'.format(100.0 * i / len(transactions)), end='')
            else:
                print('.', end='')
        i += 1
        frequent_transaction_items = transaction.intersection(frequent_item_ids)
        for pair in itertools.combinations(frequent_transaction_items, 2):
            pair_counts[pair] += 1
    for pair, pair_count in pair_counts.items():
        if pair_count >= min_count:
            frequent_pairs[pair] = pair_count
    return frequent_pairs
frequent_movie_pairs = find_frequent_pairs(transactions, frequent_movies, min_pair_support)
len(frequent_movie_pairs)

.........7.2%.........14.4%.........21.7%.........28.9%.........36.1%.........43.3%.........50.5%.........57.8%.........65.0%.........72.2%.........79.4%.........86.6%.........93.9%........

2468974

In [8]:
# calculate association rules that meet the minimum confidence criteria based on the frequent item pairs
def calculate_association_rules(frequent_items, frequent_pairs, n_transactions):
    rules = []
    for source, source_freq in frequent_items.items():
        for pair, pair_freq in frequent_pairs.items():
            if source in pair:
                target = list(set(pair).difference(set([source])))[0]
                support = 1.0 * pair_freq / n_transactions
                confidence = 1.0 * pair_freq / source_freq
                if confidence > min_confidence:
                    rules.append((source, target, support, confidence))
    return rules
rules = calculate_association_rules(frequent_movies, frequent_movie_pairs, len(transactions))
len(rules)

36331

In [9]:
# create a data frame of the rules
rules_dict = []
movie_names_map = dict(zip(movies.movie_id.tolist(), movies.title.tolist()))
for r in rules:
    rules_dict.append({
        'source_id': r[0],
        'source_name': movie_names_map[r[0]],
        'target_id': r[1],
        'target_name': movie_names_map[r[1]],
        'support': r[2],
        'confidence': r[3]})
rules_df = pd.DataFrame(rules_dict)[['source_id', 'source_name', 'target_id', 'target_name', 'support', 'confidence']]
rules_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 36331 entries, 0 to 36330
Data columns (total 6 columns):
 #   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  
 0   source_id    36331 non-null  int64  
 1   source_name  36331 non-null  object 
 2   target_id    36331 non-null  int64  
 3   target_name  36331 non-null  object 
 4   support      36331 non-null  float64
 5   confidence   36331 non-null  float64
dtypes: float64(2), int64(2), object(2)
memory usage: 1.7+ MB


In [10]:
rules_df.sample(20)

Unnamed: 0,source_id,source_name,target_id,target_name,support,confidence
14727,2657,"Rocky Horror Picture Show, T...",1270,Back to the Future (1985),0.042897,0.551215
24485,2600,eXistenZ (1999),608,Fargo (1996),0.015575,0.507768
3246,3988,How the Grinch Stole Christm...,356,Forrest Gump (1994),0.017806,0.544251
19724,8958,Ray (2004),4995,"Beautiful Mind, A (2001)",0.01299,0.553879
24702,8949,Sideways (2004),2959,Fight Club (1999),0.024889,0.572877
2634,3705,Bird on a Wire (1990),2628,Star Wars: Episode I - The P...,0.006925,0.593074
3710,2082,"Mighty Ducks, The (1992)",1580,Men in Black (a.k.a. MIB) (1...,0.015488,0.588316
7667,1431,Beverly Hills Ninja (1997),1210,Star Wars: Episode VI - Retu...,0.00886,0.523018
10209,599,"Wild Bunch, The (1969)",1265,Groundhog Day (1993),0.009358,0.501936
25313,37386,Aeon Flux (2005),589,Terminator 2: Judgment Day (...,0.007278,0.552632


In [11]:
len(rules_df.source_id.unique())

2212

In [12]:
len(rules_df.target_id.unique())

479

In [13]:
rules_df.sort_values(by='support', ascending=False).head(20)

Unnamed: 0,source_id,source_name,target_id,target_name,support,confidence
1727,296,Pulp Fiction (1994),356,Forrest Gump (1994),0.212415,0.545849
3164,356,Forrest Gump (1994),296,Pulp Fiction (1994),0.212415,0.554638
1726,296,Pulp Fiction (1994),318,"Shawshank Redemption, The (1...",0.210884,0.541916
1763,318,"Shawshank Redemption, The (1...",296,Pulp Fiction (1994),0.210884,0.575023
3163,356,Forrest Gump (1994),318,"Shawshank Redemption, The (1...",0.193822,0.50609
1764,318,"Shawshank Redemption, The (1...",356,Forrest Gump (1994),0.193822,0.528499
853,480,Jurassic Park (1993),296,Pulp Fiction (1994),0.189569,0.550145
2021,457,"Fugitive, The (1993)",480,Jurassic Park (1993),0.171539,0.598609
2019,457,"Fugitive, The (1993)",296,Pulp Fiction (1994),0.17094,0.596518
2364,150,Apollo 13 (1995),356,Forrest Gump (1994),0.167734,0.607146


In [14]:
rules_df.sort_values(by='confidence', ascending=False).head(20)

Unnamed: 0,source_id,source_name,target_id,target_name,support,confidence
16771,2034,"Black Hole, The (1979)",260,Star Wars: Episode IV - A Ne...,0.010022,0.789085
30651,68791,Terminator Salvation (2009),2571,"Matrix, The (1999)",0.009264,0.751611
35016,159,Clockers (1995),296,Pulp Fiction (1994),0.010311,0.746472
34285,544,Striking Distance (1993),589,Terminator 2: Judgment Day (...,0.009264,0.745064
23414,259,Kiss of Death (1995),296,Pulp Fiction (1994),0.008672,0.744575
11081,94864,Prometheus (2012),2571,"Matrix, The (1999)",0.008787,0.742526
6019,6934,"Matrix Revolutions, The (2003)",2571,"Matrix, The (1999)",0.052963,0.741508
15190,548,Terminal Velocity (1994),589,Terminator 2: Judgment Day (...,0.012037,0.737284
22013,5040,Conan the Destroyer (1984),2571,"Matrix, The (1999)",0.007748,0.73594
11386,7373,Hellboy (2004),2571,"Matrix, The (1999)",0.023886,0.734458


In [15]:
rules_df[rules_df.source_id == 52281].sort_values(by=['confidence', 'support'], ascending=False)

Unnamed: 0,source_id,source_name,target_id,target_name,support,confidence
13585,52281,Grindhouse (2007),2571,"Matrix, The (1999)",0.014679,0.692439
13600,52281,Grindhouse (2007),32587,Sin City (2005),0.013199,0.622616
13586,52281,Grindhouse (2007),47,Seven (a.k.a. Se7en) (1995),0.01299,0.612738
13599,52281,Grindhouse (2007),5952,Lord of the Rings: The Two T...,0.0126,0.594346
13596,52281,Grindhouse (2007),7438,Kill Bill: Vol. 2 (2004),0.01234,0.582084
13591,52281,Grindhouse (2007),1089,Reservoir Dogs (1992),0.012224,0.576635
13584,52281,Grindhouse (2007),33794,Batman Begins (2005),0.012094,0.570504
13604,52281,Grindhouse (2007),3578,Gladiator (2000),0.01177,0.555177
13587,52281,Grindhouse (2007),58559,"Dark Knight, The (2008)",0.011661,0.550068
13602,52281,Grindhouse (2007),32,Twelve Monkeys (a.k.a. 12 Mo...,0.011632,0.548706


In [17]:
import csv
rules_df.to_csv(f'{data_dir}/association_rules.csv', index=False, quoting=csv.QUOTE_NONNUMERIC)