# Synthetic Group Generation

This notebook generates synthetic groups based on the similarity/dissimilarity of user embeddings. Three kinds of groups are generated: similar, divergent, and opposing (two mutually opposite subgroups).

In [None]:
# IMPORTS

import mlflow
import torch
from datasets import EchoNestLoader, LastFm1kLoader, DataLoader
from utils import Utils
from models import ELSA
import tqdm
import numpy as np
from torch.nn import functional as F
import random
import os
import plotly.express as px
import pandas as pd

# Device handle returned by the project helper; further down it is passed to
# the model (`.to(device)`) and to the interaction DataLoader.
device = Utils.set_device()

In [None]:
# ---- notebook configuration ----

# Which data and which trained model to use.
DATASET = 'LastFM1k' # 'LastFM1k' or 'EchoNest'
MODEL_RUN_ID = '32b65a3a9edf4ff4b46e9d8385d93bc4'  # MLflow run id of the ELSA model

# Destination directory for the generated groups.
OUT_DIR = './data/synthetic_groups'

# Quantiles of the pairwise user-similarity values that define the thresholds:
# Td is the lower bound used by `is_user_similar`,
# Ts is the upper bound used by `is_user_divergent`.
Td_QUANTILE = 0.9
Ts_QUANTILE = 0.1

# Group sizes to generate for each kind of group.
SIMILAR_GROUPS = [3, 5]
DIVIDE_GROUPS = [3, 5]  # sizes of the divergent groups
OPPOSING_GROUPS = [[2, 1], [3, 2], [4, 1]]  # [first subgroup size, second subgroup size]

# How many groups are generated for every size above.
FINAL_GROUP_COUNT = 75

# NOTE(review): not referenced in the visible cells - confirm whether it is still needed.
MAX_RECOMMENDATION_OVERLAP = 0.8

## Load run, dataset, model

In [None]:
# Fetch the MLflow run and make sure it matches what this notebook expects.

run = mlflow.get_run(MODEL_RUN_ID)
params = run.data.params
# drop the URI scheme so the artifact location can be used as a plain path
artifact_path = run.info.artifact_uri.replace('file://', '') # type: ignore

# only ELSA runs trained on the configured dataset are supported
assert 'ELSA' == params['model'], 'Model from run is not ELSA -> not supported'
assert DATASET == params['dataset'], 'Dataset from run is not the same as the current dataset -> not supported'

In [None]:
# Rebuild the model with the dimensions logged in the run, then restore its weights.

items, factors = (int(params[key]) for key in ('items', 'factors'))

model = ELSA(items, factors)
model = model.to(device)
# the optimizer is never stepped here; it is only required by the checkpoint loader
optimizer = torch.optim.Adam(model.parameters())
Utils.load_checkpoint(model, optimizer, f'{artifact_path}/checkpoint.ckpt')
model.eval()

In [None]:
# load dataset

class Config:
    val_ratio = 0.1
    test_ratio = 0.1

#Load dataset
match DATASET:
    case 'EchoNest':
        dataset_loader = EchoNestLoader()
    case 'LastFM1k':
        dataset_loader = LastFm1kLoader()
    case _:
        raise ValueError(f'Dataset {args.dataset} not supported. Check typos.')
dataset_loader.prepare(Config())

## Compute similarity matrix and thresholds

In [None]:
# load interactions
csr_interactions = dataset_loader.csr_interactions
# shuffle=False keeps the batch order fixed, so concatenating the batches
# preserves the row order of the interactions
interactions_batches = DataLoader(csr_interactions, batch_size=1024, device=device, shuffle=False)

# create user embeddings
# Pure inference: `model.eval()` does not switch autograd off, so gradient
# tracking is disabled explicitly. Otherwise every batch embedding (and
# everything derived from it) would keep its computation graph alive.
batches_embeddings = []
with torch.no_grad():
    for batch in tqdm.tqdm(interactions_batches, desc='Creating user embeddings'):
        batch_embeddings = model.encode(batch)
        batches_embeddings.append(batch_embeddings)
user_embeddings = torch.cat(batches_embeddings)

In [None]:
# Pairwise cosine similarity of all user embeddings, rescaled to [0, 1].
normalized_embeddings = F.normalize(user_embeddings, p=2, dim=1)
similarity_matrix = normalized_embeddings @ normalized_embeddings.T
# cosine similarity lies in [-1, 1] (-1 = exactly opposite embeddings); shift and scale it to [0, 1]
similarity_matrix = similarity_matrix.add(1).div(2)

print(f'Similarity matrix shape: {similarity_matrix.shape}')

In [None]:
# compute thresholds
# Take every unordered user pair exactly once (strict upper triangle).
# Selecting the pairs by index, rather than zeroing the lower triangle and
# filtering with `!= 0`, keeps pairs whose similarity is genuinely 0.
pair_index = torch.triu_indices(
    similarity_matrix.shape[0],
    similarity_matrix.shape[1],
    offset=1,
    device=similarity_matrix.device,
)
similarity_values = similarity_matrix[pair_index[0], pair_index[1]]
del pair_index  # the index tensor is large and no longer needed

# Td: lower bound on mean similarity for "similar" users
# Ts: upper bound on mean similarity for "divergent" users
Td = similarity_values.quantile(Td_QUANTILE)
Ts = similarity_values.quantile(Ts_QUANTILE)

print(f'Td: {Td}, Ts: {Ts}')

## Find groups

In [None]:

def is_user_similar(group_members, user):
    """Tell whether `user` is similar enough to the current group.

    The user's mean similarity to all `group_members` (read from the global
    `similarity_matrix`) must be at least the global threshold `Td`.

    Args:
        group_members: List of user indices already in the group.
        user: Index of the candidate user.

    Returns:
        `True` for an empty group; otherwise the result of comparing the mean
        similarity with `Td`.
    """
    if group_members:
        return similarity_matrix[group_members, user].mean() >= Td
    # nobody to compare against: any user is acceptable
    return True

def _similar_group(group_size):
    """Make one randomized attempt to assemble a group of similar users.

    A seed user is drawn uniformly at random. Then random candidates are
    sampled from the remaining users and accepted whenever `is_user_similar`
    approves them against the current members.

    Args:
        group_size: Requested number of users in the group.

    Returns:
        List of user indices (rows of `similarity_matrix`).

    Raises:
        ValueError: If `group_size` exceeds the number of users, so the group
            can never be completed.
        TimeoutError: If the group is still incomplete after 1,000 samples.
    """
    user_count = similarity_matrix.shape[0]
    if group_size > user_count:
        # Retrying cannot help here; fail loudly instead of looping forever.
        raise ValueError(f'Group size {group_size} exceeds the number of users ({user_count})')

    group_members = [random.randint(0, user_count-1)]
    # The seed alone may already satisfy the request (group_size <= 1).
    # Without this check one extra user could be appended before the size test.
    if len(group_members) >= group_size:
        return group_members

    rest = list(set(range(user_count)) - set(group_members))
    for _ in range(1_000): # bounded: a matching group may not exist for this seed
        user = random.choice(rest)
        if is_user_similar(group_members, user):
            group_members.append(user)
            rest.remove(user)
            if len(group_members) >= group_size:
                return group_members
    raise TimeoutError('Could not find similar group')

def similar_group(group_size):
    """Return a similar group of `group_size` users, retrying until one is found.

    Each attempt is delegated to `_similar_group`; an attempt that ends with
    `TimeoutError` is discarded and a new one is started. There is no limit
    on the number of attempts.
    """
    while True:
        try:
            group = _similar_group(group_size)
        except TimeoutError:
            # this attempt ran out of samples; try again from scratch
            continue
        return group
        

In [None]:
def is_user_divergent(group_members, user):
    """Tell whether `user` is dissimilar enough from the current group.

    The user's mean similarity to all `group_members` (read from the global
    `similarity_matrix`) must be at most the global threshold `Ts`.

    Args:
        group_members: List of user indices already in the group.
        user: Index of the candidate user.

    Returns:
        `True` for an empty group; otherwise the result of comparing the mean
        similarity with `Ts`.
    """
    if group_members:
        return similarity_matrix[group_members, user].mean() <= Ts
    # nobody to compare against: any user is acceptable
    return True

def _divergent_group(group_size):
    """Make one randomized attempt to assemble a group of divergent users.

    A seed user is drawn uniformly at random. Then random candidates are
    sampled from the remaining users and accepted whenever
    `is_user_divergent` approves them against the current members.

    Args:
        group_size: Requested number of users in the group.

    Returns:
        List of user indices (rows of `similarity_matrix`).

    Raises:
        ValueError: If `group_size` exceeds the number of users, so the group
            can never be completed.
        TimeoutError: If the group is still incomplete after 1,000 samples.
    """
    user_count = similarity_matrix.shape[0]
    if group_size > user_count:
        # Retrying cannot help here; fail loudly instead of looping forever.
        raise ValueError(f'Group size {group_size} exceeds the number of users ({user_count})')

    group_members = [random.randint(0, user_count-1)]
    # The seed alone may already satisfy the request (group_size <= 1).
    # Without this check one extra user could be appended before the size test.
    if len(group_members) >= group_size:
        return group_members

    rest = list(set(range(user_count)) - set(group_members))
    for _ in range(1_000): # bounded: a matching group may not exist for this seed
        user = random.choice(rest)
        if is_user_divergent(group_members, user):
            group_members.append(user)
            rest.remove(user)
            if len(group_members) >= group_size:
                return group_members
    raise TimeoutError('Could not find divergent group')

def divergent_group(group_size):
    """Return a divergent group of `group_size` users, retrying until one is found.

    Each attempt is delegated to `_divergent_group`; an attempt that ends
    with `TimeoutError` is discarded and a new one is started. There is no
    limit on the number of attempts.
    """
    while True:
        try:
            group = _divergent_group(group_size)
        except TimeoutError:
            # this attempt ran out of samples; try again from scratch
            continue
        return group

In [None]:
from copy import deepcopy

def _opposing_group(group_size):
    """Make one randomized attempt to build a group of two opposing subgroups.

    A random seed user starts the first subgroup. Random candidates are then
    sampled; a candidate joins the subgroup currently being expanded when it
    is similar to that subgroup (`is_user_similar`) and divergent from the
    other one (`is_user_divergent`). Expansion alternates between the two
    subgroups and sticks to the unfinished one once the other is full.

    Args:
        group_size: Two-element sequence (list or tuple) with the sizes of
            the first and second subgroup. The seed user always goes to the
            first subgroup, so its size must be at least 1.

    Returns:
        Flat list of user indices: members of the first subgroup followed by
        members of the second subgroup.

    Raises:
        ValueError: If the sizes are invalid or exceed the number of users.
        TimeoutError: If the group is still incomplete after 1,000 samples.
    """
    user_count = similarity_matrix.shape[0]
    # Copy into a fresh list: tuples are accepted as well, and the caller's
    # sequence is never modified by the countdown below.
    groups_lefts = list(group_size)
    if len(groups_lefts) != 2 or groups_lefts[0] < 1 or groups_lefts[1] < 0:
        raise ValueError(f'Invalid subgroup sizes: {group_size}')
    if sum(groups_lefts) > user_count:
        raise ValueError(f'Subgroup sizes {group_size} exceed the number of users ({user_count})')

    # choose the first user
    group_members = ([random.randint(0, user_count-1)], [])
    groups_lefts[0] -= 1

    # Sizes such as (1, 0) are already satisfied by the seed user alone.
    if groups_lefts[0] == 0 and groups_lefts[1] == 0:
        return group_members[0] + group_members[1]

    rest = list(set(range(user_count)) - set(group_members[0]))

    index = 1
    for _ in range(1_000):
        # choose the subgroup to expand; skip to the other one if it is full
        group_to_expand = index % 2
        if groups_lefts[group_to_expand] == 0:
            index += 1
            group_to_expand = index % 2

        user = random.choice(rest)
        if is_user_similar(group_members[group_to_expand], user) and is_user_divergent(group_members[1-group_to_expand], user):
            group_members[group_to_expand].append(user)
            groups_lefts[group_to_expand] -= 1
            rest.remove(user)
            # completion can only change after a user has been accepted
            if groups_lefts[0] == 0 and groups_lefts[1] == 0:
                return group_members[0] + group_members[1]
        index += 1
    raise TimeoutError('Could not find opposing group')

def opposing_group(group_size):
    """Return an opposing group with subgroup sizes `group_size`, retrying until found.

    Each attempt is delegated to `_opposing_group`; an attempt that ends with
    `TimeoutError` is discarded and a new one is started. There is no limit
    on the number of attempts.
    """
    while True:
        try:
            group = _opposing_group(group_size)
        except TimeoutError:
            # this attempt ran out of samples; try again from scratch
            continue
        return group
    

In [None]:
def overall_similarity(group):
    """Return the mean pairwise similarity between the members of `group`.

    Every unordered pair of members is counted once. Pairs are selected by
    index (strict upper triangle), so a pair whose similarity is exactly 0
    still contributes to the mean.

    Args:
        group: List of user indices into `similarity_matrix`.

    Returns:
        The mean as a Python float; `nan` when the group has fewer than two
        members (there are no pairs to average).
    """
    pairwise = similarity_matrix[group][:, group]
    rows, cols = torch.triu_indices(
        pairwise.shape[0], pairwise.shape[1], offset=1, device=pairwise.device
    )
    return pairwise[rows, cols].mean().item()

def subgroup_similarity(group, subgroup_size):
    """Return the mean pairwise similarity inside each of the two subgroups.

    Args:
        group: Flat list of user indices; the first `subgroup_size[0]`
            entries form the first subgroup, the remaining ones the second.
        subgroup_size: Sequence whose first element is the size of the first
            subgroup.

    Returns:
        Tuple `(first_mean, second_mean)` of Python floats. A subgroup with
        fewer than two members has no pairs and yields `nan`.
    """
    def mean_pairwise(members):
        # Select the strict upper triangle by index so that pairs with a
        # similarity of exactly 0 are kept in the mean.
        pairwise = similarity_matrix[members][:, members]
        rows, cols = torch.triu_indices(
            pairwise.shape[0], pairwise.shape[1], offset=1, device=pairwise.device
        )
        return pairwise[rows, cols].mean().item()

    sub1, sub2 = group[:subgroup_size[0]], group[subgroup_size[0]:]
    return mean_pairwise(sub1), mean_pairwise(sub2)

In [None]:
%%time
similar_groups = {}
for group_size in SIMILAR_GROUPS:
    similar_groups[group_size] = [similar_group(group_size) for _ in range(FINAL_GROUP_COUNT)]

In [None]:
%time
divergent_groups = {}
for group_size in DIVIDE_GROUPS:
    divergent_groups[group_size] = [divergent_group(group_size) for _ in range(FINAL_GROUP_COUNT)]

In [None]:
# Generate FINAL_GROUP_COUNT opposing groups for every configured pair of
# subgroup sizes; the size pair is converted to a tuple to serve as dict key.
opposing_groups = {
    tuple(group_size): [opposing_group(group_size) for _ in range(FINAL_GROUP_COUNT)]
    for group_size in OPPOSING_GROUPS
}

## Plot Histograms

In [None]:
# show histogram for overall similarity for similar groups
from plotly import express as px

# overall similarity of every generated similar group, keyed by group size
similar_groups_overall_similarity = {
    size: list(map(overall_similarity, members))
    for size, members in similar_groups.items()
}

# the last expression of the cell is the figure that gets rendered
px.histogram(x=similar_groups_overall_similarity[3], nbins=50, title='Overall similarity for similar groups')
# px.histogram(x=similar_groups_overall_similarity[5], nbins=50, title='Overall similarity for similar groups')

In [None]:
# overall similarity of every generated divergent group, keyed by group size
divergent_groups_overall_similarity = {
    size: list(map(overall_similarity, members))
    for size, members in divergent_groups.items()
}

# the last expression of the cell is the figure that gets rendered
# px.histogram(x=divergent_groups_overall_similarity[3], nbins=50, title='Overall similarity for divergent groups')
px.histogram(x=divergent_groups_overall_similarity[5], nbins=50, title='Overall similarity for divergent groups')

In [None]:
# overall similarity of every generated opposing group, keyed by (size 1, size 2)
opposing_groups_overall_similarity = {group_size: [overall_similarity(group) for group in groups] for group_size, groups in opposing_groups.items()}

# A notebook cell renders only the value of its last expression, so each
# figure is displayed explicitly; otherwise all but the last one are lost.
px.histogram(x=opposing_groups_overall_similarity[(2,1)], nbins=50, title='Overall similarity for opposing groups (2,1)').show()
px.histogram(x=opposing_groups_overall_similarity[(3,2)], nbins=50, title='Overall similarity for opposing groups (3,2)').show()
px.histogram(x=opposing_groups_overall_similarity[(4,1)], nbins=50, title='Overall similarity for opposing groups (4,1)').show()

# subgroup similarity
# index [1] selects the similarity inside the second subgroup
opposing_groups_subgroup_similarity = {group_size: [subgroup_similarity(group, group_size)[1] for group in groups] for group_size, groups in opposing_groups.items()}
# the title now names the (3,2) groups that are actually plotted
px.histogram(x=opposing_groups_subgroup_similarity[(3,2)], nbins=50, title='Subgroup similarity for opposing groups (3,2)').show()



In [None]:
def group_similarity(group1, group2):
    """Return the fraction of `group1`'s users that also appear in `group2`.

    The number of distinct shared users is divided by `len(group1)`, so the
    measure is not symmetric when the groups differ in length.

    Args:
        group1: Sequence of user indices; its length is the denominator.
        group2: Sequence of user indices to compare against.

    Returns:
        A float in [0, 1]; 0.0 when `group1` is empty (nothing is shared).
    """
    if not group1:
        # avoid ZeroDivisionError: an empty group shares no users
        return 0.0
    return len(set(group1) & set(group2)) / len(group1)

def show_too_similar_groups(groups, threshold):
    """Print every pair of groups whose user overlap reaches `threshold`.

    Each group is compared with all groups that follow it in `groups`, using
    `group_similarity(earlier, later)`. Reported pairs are numbered from 0.

    Args:
        groups: List of groups (each a list of user indices).
        threshold: Minimum `group_similarity` value that gets reported.
    """
    index = 0
    for next_position, group1 in enumerate(groups, start=1):
        # only look at groups after the current one, so each pair is visited once
        for group2 in groups[next_position:]:
            similarity = group_similarity(group1, group2)
            if not (similarity >= threshold):
                continue
            print(f'{index}: Groups {group1} and {group2} are too similar: {similarity}')
            index += 1

In [None]:
# list similar groups of size 5 that share at least 30 % of their users
show_too_similar_groups(groups=similar_groups[5], threshold=0.3)

In [None]:
# Persist every set of groups as a .npy file, one file per group size.

out_path = f'{OUT_DIR}/{DATASET}'
os.makedirs(out_path, exist_ok=True)

# similar groups -> similar_<size>.npy
for group_size in similar_groups:
    group_array = np.array(similar_groups[group_size])
    np.save(f'{out_path}/similar_{group_size}.npy', group_array)

# divergent groups -> divergent_<size>.npy
for group_size in divergent_groups:
    group_array = np.array(divergent_groups[group_size])
    np.save(f'{out_path}/divergent_{group_size}.npy', group_array)

# opposing groups -> opposing_<size1>_<size2>.npy (keys are size pairs)
for group_size in opposing_groups:
    group_array = np.array(opposing_groups[group_size])
    np.save(f'{out_path}/opposing_{group_size[0]}_{group_size[1]}.npy', group_array)

