## Check number of parameters and relevant info for all models

In [7]:
import os
import sys
import importlib
import torch
import yaml
import numpy as np
import pandas as pd
from pathlib import Path
import scipy.sparse as sp

# Add src to path
sys.path.insert(0, 'src')

# ============================================================================
# CONFIGURATION SETUP
# ============================================================================

# Load base configs
with open('src/configs/overall.yaml', 'r') as f:
    base_config = yaml.safe_load(f)

with open('src/configs/dataset/baby.yaml', 'r') as f:
    dataset_config = yaml.safe_load(f)

config = {**base_config, **dataset_config}
config['dataset'] = 'baby'
config['data_path'] = 'data/'
config['device'] = torch.device('cpu')

# ============================================================================
# MOCK DATASET AND DATALOADER
# ============================================================================

# Load actual baby dataset for inter_matrix
data_path = 'data/baby'
inter_df = pd.read_csv(f'{data_path}/baby.inter', sep='\t')

class MockDataset:
    """Mock dataset that provides interaction matrix and basic statistics"""
    def __init__(self, inter_df):
        self.user_num = int(inter_df['userID'].max()) + 1
        self.item_num = int(inter_df['itemID'].max()) + 1
        self.df = inter_df
        
    def get_user_num(self):
        return self.user_num
    
    def get_item_num(self):
        return self.item_num
    
    def inter_matrix(self, form='coo'):
        """Generate sparse interaction matrix"""
        users = self.df['userID'].values
        items = self.df['itemID'].values
        data = np.ones(len(users))
        
        if form == 'csr':
            return sp.csr_matrix((data, (users, items)), shape=(self.user_num, self.item_num))
        else:  # Default to 'coo'
            return sp.coo_matrix((data, (users, items)), shape=(self.user_num, self.item_num))

class MockDataLoader:
    """Mock dataloader that wraps dataset and provides inter_matrix method"""
    def __init__(self, dataset):
        self.dataset = dataset
    
    def inter_matrix(self, form='coo'):
        """Pass-through to dataset's inter_matrix method"""
        return self.dataset.inter_matrix(form)

mock_dataset = MockDataset(inter_df)
mock_dataloader = MockDataLoader(mock_dataset)

# ============================================================================
# MODEL CLASS NAME MAPPING
# ============================================================================

# Map filename to actual class name for models with non-standard naming
CLASS_NAME_MAP = {
    'dualgnn': 'DualGNN',
    'itemknncbf': 'ItemKNNCBF',
    'lgmrec': 'LGMRec',
    'lightgcn': 'LightGCN',
    'selfcfed_lgn': 'SELFCFED_LGN',
    'slmrec': 'SLMRec',
}

# ============================================================================
# DEFAULT PARAMETERS
# ============================================================================

# Common parameters with sensible defaults
DEFAULT_PARAMS = {
    'lambda_coeff': 0.1,
    'cf_model': 'mf',
    'n_layers': 2,
    'dropout': 0.1,
    'reg_weight': 1e-4,
    'cl_weight': 0.1,
    'temperature': 0.2,
    'ssl_tau': 0.5,
    'ssl_reg': 0.1,
    'hyper_num': 64,
    'n_ui_layers': 2,
    'n_mm_layers': 2,
    'n_layers_feat': 1,
    'knn_k': 10,
    'mm_image_weight': 0.5,
    'aggr_mode': 'mean',
    'degree_ratio': 0.5,
    'cl_loss': 0.01,
    'dropout_rate': 0.1,
    'image_knn_k': 10,
    'text_knn_k': 10,
    'feat_embed_dim': 64,
}

# Parameters that should remain as lists (don't take first element)
LIST_PARAMS = {
    'weight_size',
    'mess_dropout',
    'hyper_parameters',
}

# Models that require special handling
MODELS_TO_SKIP = {
    'layergcn': 'Imports from models.common (circular import)',
    'pgl': 'Requires sparsesvd package',
    'damrs': 'Requires preprocessed graph files (item_graph_dict_2.npy)',
}

# ============================================================================
# MODEL ANALYSIS
# ============================================================================

models_dir = Path('src/models')
model_files = [f.stem for f in models_dir.glob('*.py') 
               if f.stem not in ['__init__', '__pycache__']]

print(f"Found {len(model_files)} models in src/models/")
print("=" * 90)

model_info = []
skipped_models = []

for model_name in sorted(model_files):
    # Skip known problematic models
    if model_name in MODELS_TO_SKIP:
        reason = MODELS_TO_SKIP[model_name]
        print(f"âŠ˜ {model_name.upper():20s} | Skipped: {reason}")
        skipped_models.append((model_name, reason))
        continue
    
    try:
        # Import the model module
        module = importlib.import_module(f'models.{model_name}')
        
        # Try to find the model class
        # First check if we have a mapping
        if model_name in CLASS_NAME_MAP:
            class_name = CLASS_NAME_MAP[model_name]
        else:
            # Try standard naming conventions
            possible_names = [
                model_name.upper(),  # UPPERCASE
                ''.join([word.capitalize() for word in model_name.split('_')]),  # CamelCase
                model_name,  # lowercase
            ]
            class_name = None
            for name in possible_names:
                if hasattr(module, name):
                    class_name = name
                    break
        
        if class_name is None or not hasattr(module, class_name):
            print(f"âœ— {model_name.upper():20s} | Error: Could not find model class")
            continue
        
        model_class = getattr(module, class_name)
        
        # Build config for this model
        test_config = config.copy()
        
        # Load model-specific config if exists
        model_config_path = f'src/configs/model/{model_name}.yaml'
        if os.path.exists(model_config_path):
            with open(model_config_path, 'r') as f:
                model_specific_config = yaml.safe_load(f)
            test_config.update(model_specific_config)
        
        # Add model name
        test_config['model'] = model_name
        
        # Add default parameters for any missing keys
        for key, value in DEFAULT_PARAMS.items():
            if key not in test_config:
                test_config[key] = value
        
        # Handle list parameters (take first value for instantiation, except for LIST_PARAMS)
        for key, value in test_config.items():
            if isinstance(value, list) and len(value) > 0 and key not in LIST_PARAMS:
                test_config[key] = value[0]
        
        # Instantiate model
        model = model_class(test_config, mock_dataloader)
        
        # Count parameters
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        frozen_params = total_params - trainable_params
        
        # Analyze model characteristics
        is_multimodal = (hasattr(model, 'v_feat') or hasattr(model, 't_feat') or 
                        'mm' in model_name.lower() or 'multimodal' in model_name.lower())
        has_gcn = ('gcn' in model_name.lower() or 'gnn' in model_name.lower() or 
                  'graph' in model_name.lower() or model_name.lower() == 'lightgcn')
        
        # Determine complexity
        if has_gcn and is_multimodal:
            complexity = 'High'
        elif has_gcn or is_multimodal:
            complexity = 'Medium'
        else:
            complexity = 'Low'
        
        model_info.append({
            'model': model_name.upper(),
            'total_params': total_params,
            'trainable_params': trainable_params,
            'frozen_params': frozen_params,
            'is_multimodal': 'âœ“' if is_multimodal else 'âœ—',
            'has_gcn': 'âœ“' if has_gcn else 'âœ—',
            'complexity': complexity
        })
        
        print(f"âœ“ {model_name.upper():20s} | {total_params:11,} params ({total_params/1e6:5.2f}M)")
        
    except Exception as e:
        error_msg = str(e)
        # Truncate long error messages
        if len(error_msg) > 60:
            error_msg = error_msg[:57] + "..."
        print(f"âœ— {model_name.upper():20s} | Error: {error_msg}")

print("=" * 90)
print(f"\nSuccessfully analyzed: {len(model_info)} / {len(model_files)} models")
print(f"Skipped (known issues): {len(skipped_models)}")
print(f"Failed (unexpected): {len(model_files) - len(model_info) - len(skipped_models)}\n")

# ============================================================================
# RESULTS TABLE
# ============================================================================

if model_info:
    df_models = pd.DataFrame(model_info)
    df_models = df_models.sort_values('total_params')
    df_models['params_M'] = df_models['total_params'] / 1e6
    
    print("=" * 100)
    print("MODEL COMPARISON (Sorted by parameter count - Lower = Typically Faster)")
    print("=" * 100)
    print(f"{'Rank':>5} {'Model':<12} {'Params (M)':>10}  {'Trainable':>11}  "
          f"{'Frozen':>8} {'Multimodal':>11} {'GCN':>4} {'Complexity':>11}")
    
    for rank, (_, row) in enumerate(df_models.iterrows(), 1):
        print(f"{rank:>5} {row['model']:<12} {row['params_M']:>10.2f} "
              f"{row['trainable_params']:>11,} {row['frozen_params']:>8,} "
              f"{row['is_multimodal']:>11} {row['has_gcn']:>4} {row['complexity']:>11}")
    
    print("=" * 100)
    
    # Summary statistics
    print(f"\nðŸ“Š SUMMARY STATISTICS:")
    print(f"   Total models analyzed: {len(df_models)}")
    print(f"   Smallest model: {df_models.iloc[0]['model']} ({df_models.iloc[0]['params_M']:.2f}M params)")
    print(f"   Largest model:  {df_models.iloc[-1]['model']} ({df_models.iloc[-1]['params_M']:.2f}M params)")
    print(f"   Average size:   {df_models['params_M'].mean():.2f}M params")
    print(f"   Multimodal models: {(df_models['is_multimodal'] == 'âœ“').sum()}")
    print(f"   GCN/GNN-based: {(df_models['has_gcn'] == 'âœ“').sum()}")
    
    # Speed ranking by complexity
    print(f"\nâš¡ TRAINING SPEED RANKING:")
    
    for complexity_level, emoji, desc in [
        ('Low', 'ðŸŸ¢', 'FASTEST (Simple architectures)'),
        ('Medium', 'ðŸŸ¡', 'MODERATE (GCN or multimodal)'),
        ('High', 'ðŸ”´', 'SLOWEST (GCN + multimodal)')
    ]:
        subset = df_models[df_models['complexity'] == complexity_level]
        if len(subset) > 0:
            print(f"\n   {emoji} {desc}:")
            for _, row in subset.iterrows():
                print(f"      â€¢ {row['model']:<12} {row['params_M']:6.2f}M params")
    
    # Recommendations
    print(f"\nðŸŽ¯ RECOMMENDATIONS FOR MUSIC4ALL DATASET:")
    
    # Get top 3 fastest
    top3 = df_models.head(3)
    print(f"\n   Fast iteration models (< 5M params):")
    for _, row in top3.iterrows():
        if row['params_M'] < 5:
            print(f"      â€¢ {row['model']:<12} {row['params_M']:5.2f}M params - "
                  f"{'Multimodal' if row['is_multimodal'] == 'âœ“' else 'CF-only'}")
    
    # GCN models
    gcn_models = df_models[df_models['has_gcn'] == 'âœ“'].head(3)
    if len(gcn_models) > 0:
        print(f"\n   Graph-based models (best performance/speed trade-off):")
        for _, row in gcn_models.iterrows():
            print(f"      â€¢ {row['model']:<12} {row['params_M']:5.2f}M params")
    
    print(f"\nðŸ’¡ TRAINING SPEED FACTORS:")
    print(f"   1. Parameter count: Direct impact on computation time")
    print(f"   2. Graph operations: GCN/GNN adds ~2-3x overhead")
    print(f"   3. Multimodal fusion: Feature processing adds time")
    print(f"   4. Batch size: Larger batches improve GPU efficiency")
    
    # Export to CSV for reference
    df_models[['model', 'params_M', 'complexity', 'is_multimodal', 'has_gcn']].to_csv(
        'model_comparison.csv', index=False)
    print(f"\nâœ“ Results saved to: model_comparison.csv")

Found 20 models in src/models/
âœ“ BM3                  |  33,570,688 params (33.57M)
âœ“ BPR                  |   1,695,680 params ( 1.70M)
âŠ˜ DAMRS                | Skipped: Requires preprocessed graph files (item_graph_dict_2.npy)
âœ“ BM3                  |  33,570,688 params (33.57M)
âœ“ BPR                  |   1,695,680 params ( 1.70M)
âŠ˜ DAMRS                | Skipped: Requires preprocessed graph files (item_graph_dict_2.npy)


  self.mm_adj = torch.load(mm_adj_file)


âœ“ DRAGON               |  37,305,214 params (37.31M)
âœ“ DUALGNN              |   5,438,462 params ( 5.44M)
âœ“ DUALGNN              |   5,438,462 params ( 5.44M)


  self.mm_adj = torch.load(mm_adj_file)


âœ“ FREEDOM              |  33,566,528 params (33.57M)
âœ“ GRCN                 |   4,524,478 params ( 4.52M)
âœ“ GRCN                 |   4,524,478 params ( 4.52M)
âœ“ ITEMKNNCBF           |           2 params ( 0.00M)
âœ“ ITEMKNNCBF           |           2 params ( 0.00M)


  image_adj = torch.load(image_adj_file)
  text_adj = torch.load(text_adj_file)
  text_adj = torch.load(text_adj_file)


âœ“ LATTICE              |  33,566,530 params (33.57M)
âŠ˜ LAYERGCN             | Skipped: Imports from models.common (circular import)
âœ“ LGMREC               |  33,584,320 params (33.58M)
âœ“ LGMREC               |  33,584,320 params (33.58M)
âœ“ LIGHTGCN             |   1,695,680 params ( 1.70M)
âœ“ LIGHTGCN             |   1,695,680 params ( 1.70M)


  image_adj = torch.load(image_adj_file)
  text_adj = torch.load(text_adj_file)


âœ“ MGCN                 |  33,587,392 params (33.59M)
âœ“ MMGCN                |  15,558,720 params (15.56M)
âœ“ MMGCN                |  15,558,720 params (15.56M)
âœ“ MVGAE                |     756,352 params ( 0.76M)
âŠ˜ PGL                  | Skipped: Requires sparsesvd package
âœ“ MVGAE                |     756,352 params ( 0.76M)
âŠ˜ PGL                  | Skipped: Requires sparsesvd package
âœ“ SELFCFED_LGN         |   1,699,840 params ( 1.70M)
âœ“ SELFCFED_LGN         |   1,699,840 params ( 1.70M)
use the pre adjcency matrix
âœ“ SLMREC               |   2,028,032 params ( 2.03M)
use the pre adjcency matrix
âœ“ SLMREC               |   2,028,032 params ( 2.03M)


  image_adj = torch.load(image_adj_file)
  text_adj = torch.load(text_adj_file)


âœ“ SMORE                |  33,608,198 params (33.61M)
âœ“ VBPR                 |   3,226,944 params ( 3.23M)

Successfully analyzed: 17 / 20 models
Skipped (known issues): 3
Failed (unexpected): 0

MODEL COMPARISON (Sorted by parameter count - Lower = Typically Faster)
 Rank Model        Params (M)    Trainable    Frozen  Multimodal  GCN  Complexity
    1 ITEMKNNCBF         0.00           2        0           âœ“    âœ—      Medium
    2 MVGAE              0.76     756,352        0           âœ“    âœ—      Medium
    3 BPR                1.70   1,695,680        0           âœ“    âœ—      Medium
    4 LIGHTGCN           1.70   1,695,680        0           âœ“    âœ“        High
    5 SELFCFED_LGN       1.70   1,699,840        0           âœ“    âœ—      Medium
    6 SLMREC             2.03   2,028,032        0           âœ“    âœ—      Medium
    7 VBPR               3.23   3,226,944        0           âœ“    âœ—      Medium
    8 GRCN               4.52   4,524,478        0         

## Inspect `data/baby`

In [10]:
import yaml
import numpy as np
import pandas as pd

def inspect_dataset(name, data_path, config_path):
    with open(config_path, 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)

    # Load inter_file
    inter_file_path = f'{data_path}/{config["inter_file_name"]}'
    image_file_path = f'{data_path}/{config["vision_feature_file"]}'
    text_file_path = f'{data_path}/{config["text_feature_file"]}'

    inter_df = pd.read_csv(inter_file_path, sep='\t')
    image_feat = np.load(image_file_path, allow_pickle=True)
    text_feat = np.load(text_file_path, allow_pickle=True)
    
    # Inspect the dataframe
    print(f'--- Inspecting dataset: {name} ---')
    print(f'Columns: {inter_df.columns.tolist()}')
    print(inter_df[['userID', 'itemID', 'x_label']].tail())
    print(f'#users: {len(inter_df["userID"].unique()):6}, max: {inter_df["userID"].max():6}')
    print(f'#items: {len(inter_df["itemID"].unique()):6}, max: {inter_df["itemID"].max():6}')
    print(f'#filtered interactions: {len(inter_df)}\n')

    print(f'Shape of image_feat: {image_feat.shape}')
    print(f'Shape of text_feat : {text_feat.shape}')
    print('-----------------------------------\n')

dataset_name = 'baby'
data_path = f'data/{dataset_name}'
config_path = f'src/configs/dataset/{dataset_name}.yaml'
inspect_dataset(dataset_name, data_path, config_path)

--- Inspecting dataset: baby ---
Columns: ['userID', 'itemID', 'rating', 'timestamp', 'x_label']
        userID  itemID  x_label
160787   19444    7022        0
160788   19444    6959        0
160789   19444    7005        0
160790   19444    7023        1
160791   19444    6994        2
#users:  19445, max:  19444
#items:   7050, max:   7049
#filtered interactions: 160792

Shape of image_feat: (7050, 4096)
Shape of text_feat : (7050, 384)
-----------------------------------



In [14]:
dataset_name = 'Music4All'
# Now using junction - simpler path!
data_path = f'data/{dataset_name}'  # Junction points to Dragon-for-Music/data/Music4All
config_path = f'src/configs/dataset/{dataset_name}.yaml'
inspect_dataset(dataset_name, data_path, config_path)

--- Inspecting dataset: Music4All ---
Columns: ['userID', 'itemID', 'x_label']
         userID  itemID  x_label
5058229   14124   14553        2
5058230   14124    7313        2
5058231   14124   62596        2
5058232   14124   63684        2
5058233   14124   14553        2
#users:  14125, max:  14124
#items:  80735, max:  80734
#filtered interactions: 5058234

Shape of image_feat: (80735, 1024)
Shape of text_feat : (80735, 384)
-----------------------------------



# Content below was copied from `Dragon-for-Music`

## Music4All features - Sanity check

In [1]:
import os
import numpy as np
import pandas as pd
import yaml

data_path = 'data/Music4All'
config_path = f'configs/dataset/Music4All.yaml'

with open(config_path, 'r', encoding='utf-8') as file:
    config = yaml.safe_load(file)

# Load interaction file and features
inter_file_path = os.path.join(data_path, config['inter_file_name'])
text_feature_path = os.path.join(data_path, config['text_feature_file'])            # text_feat.npy -> clean_
audio_feature_path = os.path.join(data_path, config['vision_feature_file'])         # audio_feat_mert.npy -> clean_

inter_df = pd.read_csv(inter_file_path, sep='\t')
text_feature = np.load(text_feature_path)
audio_feature = np.load(audio_feature_path)

# Inspect the dataframe
print(inter_df[['userID', 'itemID', 'x_label']].tail())
print(f'#users: {inter_df["userID"].nunique()}, max: {inter_df["userID"].max()}')
print(f'#items: {inter_df["itemID"].nunique()}, max: {inter_df["itemID"].max()}')
print(f'#filtered interactions: {len(inter_df)}')

print(f'Shape of text_feature : {text_feature.shape}')
print(f'Shape of audio_feature: {audio_feature.shape}')

         userID  itemID  x_label
5058229   14126   19690        2
5058230   14126    9878        2
5058231   14126   84677        2
5058232   14126   86170        2
5058233   14126   19690        2
#users: 14125, max: 14126
#items: 80735, max: 109268
#filtered interactions: 5058234
Shape of text_feature : (109269, 384)
Shape of audio_feature: (109269, 1024)


### Clean ids

In [2]:
# Keep only items with interactions
inter_df['itemID'], unique_item_list = pd.factorize(inter_df['itemID'], sort=True)
unique_item_list = unique_item_list.tolist()

clean_text_feature = text_feature[unique_item_list]
clean_audio_feature = audio_feature[unique_item_list]

np.save(f"{data_path}/clean_text_feat.npy", clean_text_feature)
np.save(f"{data_path}/clean_audio_feat.npy", clean_audio_feature)

In [4]:
# Make sure #users = max+1 (no skipping ids)
inter_df['userID'], uniques = pd.factorize(inter_df['userID'], sort=True)
print(inter_df.tail())
print(f'#users: {inter_df["userID"].nunique()}, max: {inter_df["userID"].max()}')
print(f'#items: {inter_df["itemID"].nunique()}, max: {inter_df["itemID"].max()}')
print(f'#filtered interactions: {len(inter_df)}')

inter_df.to_csv(f"{data_path}/clean_filtered_interactions.csv", index=False, sep='\t')

## Music4All: Audio features

See `tools.m4a_data_prep.prep_audio_features`

## Music4All: Text features
Activate the `sbert` conda env for this section

In [None]:
import os
from tools.m4a_data_prep import load_json_from_gcs, prep_text_features

music4all_dir = 'gs://music4all/Music4All/'
attr_path = os.path.join(music4all_dir, 'processed/attributes.json')
val_path = os.path.join(music4all_dir, 'processed/attribute_values.json')

attributes = load_json_from_gcs(attr_path)
val = load_json_from_gcs(val_path)

# text_feat = prep_text_features()

In [22]:
print(f'val={val.keys()}')
print(f'Genres: count={len(val["genre"]):5}; {val["genre"][:50]}')
print(f'Tags  : count={len(val["tag"]):5}; {val["tag"][:50]}')

val=dict_keys(['artist', 'album_name', 'lang', 'release', 'key', 'mode', 'genre', 'tag'])
Genres: count=  853; ['8-bit', 'a cappella', 'abstract', 'abstract hip hop', 'accordion', 'acid house', 'acid jazz', 'acid techno', 'acousmatic', 'acoustic blues', 'acoustic pop', 'adoracao', 'adventista', 'afrikaans', 'afro-funk', 'afrobeat', 'afropop', 'aggrotech', 'albanian pop', 'album rock', 'alternative country', 'alternative dance', 'alternative hip hop', 'alternative metal', 'alternative metalcore', 'alternative pop', 'alternative pop rock', 'alternative rock', 'ambient', 'ambient folk', 'ambient industrial', 'ambient techno', 'anadolu rock', 'anarcho-punk', 'anime', 'anthem', 'anti-folk', 'arabesk', 'armenian folk', 'armenian pop', 'art pop', 'art rock', 'asmr', 'atmosphere', 'atmospheric black metal', 'atmospheric doom', 'atmospheric sludge', 'australian rock', 'austropop', 'avant-garde']
Tags  : count=19541; [' ', ' ambient', ' blues rock', ' classic rock', ' dance', ' dark ambient', ' 

In [29]:
print(attributes['items'][4589])

{'song_id': '2Z3mv11Pg2Xdj3Fp', 'artist': 14998, 'album_name': 12017, 'lang': 11, 'release': 85, 'key': 7, 'mode': 1, 'genre': [338, 480, 337, 684], 'tag': [7029, 10381, 370, 7028, 7024, 14907]}


In [31]:
print(val['artist'][14998])
print(val['album_name'][12017])
print(val['release'][85])
for i in [338, 480, 337, 684]:
    print(val['genre'][i])
for i in [7029, 10381, 370, 7028, 7024, 14907]:
    print(val['tag'][i])

Ty Segall
Freedom's Goblin
2018
garage rock
lo-fi
garage punk
rock
garage rock
lo-fi
2018
garage punk
garage
rock


## Music4All: Interaction data

**`interaction.json`**
```
{
  0: [12233, 23344, ...],
  1: [],
  ...
}
```


In [4]:
import os
from tools.m4a_data_prep import prep_interaction_data

music4all_dir = 'gs://music4all/Music4All/'
inter_path = os.path.join(music4all_dir, 'processed/interactions.json')
attr_path = os.path.join(music4all_dir, 'processed/attributes.json')

df = prep_interaction_data(inter_path, attr_path)

# Inspect the dataframe
print(df.tail())
print(f'#users: {len(df["userID"].unique())}')
print(f'#items: {len(df["itemID"].unique())}')
print(f'#filtered interactions: {len(df)}')

        userID  itemID  x_label
5058229  14126   19690        2
5058230  14126    9878        2
5058231  14126   84677        2
5058232  14126   86170        2
5058233  14126   19690        2
#users: 14125
#items: 80735
#filtered interactions: 5058234
        userID  itemID  x_label
5058229  14126   19690        2
5058230  14126    9878        2
5058231  14126   84677        2
5058232  14126   86170        2
5058233  14126   19690        2
#users: 14125
#items: 80735
#filtered interactions: 5058234


### Check if `interactions.json` preserves order from `listening_history.csv`

In [None]:
from gcs_utils import read_tsv_from_gcs

hist_path = os.path.join(music4all_dir, 'listening_history.csv')
df_listening_history = read_tsv_from_gcs(hist_path)

print(df_listening_history.head(5))

            user              song         timestamp
0  user_007XIjOr  DaTQ53TUmfP93FSr  2019-02-20 12:28
1  user_007XIjOr  dGeyvi5WCOjDU7da  2019-02-20 12:35
2  user_007XIjOr  qUm54NYOjeFhmKYx  2019-02-20 12:48
3  user_007XIjOr  FtnuMT1DlevSR2n5  2019-02-20 12:52
4  user_007XIjOr  LHETTZcSZLeaVOGh  2019-02-20 13:09


In [None]:
import pandas as pd
import json
import random
from tools.m4a_data_prep import load_json_from_gcs

inter_path = 'gs://music4all/Music4All/processed/interactions.json'
interactions = load_json_from_gcs(inter_path)

def first_mismatch_with_context(list1, list2, window=10):
    """Return index, values, and context slices around the first mismatch."""
    for i, (a, b) in enumerate(zip(list1, list2)):
        if a != b:
            half = window // 2
            start = max(0, i - half)
            end = i + half + 1
            return i, list1[start:end], list2[start:end]
    if len(list1) != len(list2):
        i = min(len(list1), len(list2))
        half = window // 2
        start = max(0, i - half)
        end = i + half + 1
        return i, list1[start:end], list2[start:end]
    return None

df = df_listening_history.sort_values("timestamp").reset_index(drop=True)
# df = df_listening_history

# Build lookup dicts for encoding
user_to_idx = {u["user_id"]: idx for idx, u in enumerate(attributes["users"])}
item_to_idx = {i["song_id"]: idx for idx, i in enumerate(attributes["items"])}

# === Randomly sample 20 users ===
all_users = df["user"].unique().tolist()
sample_users = random.sample(all_users, min(20, len(all_users)))

print("Checking order preservation for 20 random users...\n")

for user_id in sample_users:
    user_idx = user_to_idx[user_id]

    # from CSV (chronological order)
    user_df = df[df["user"] == user_id]
    csv_songs = [item_to_idx[song] for song in user_df["song"].tolist()]

    # from JSON
    json_songs = interactions.get(str(user_idx), [])

    # compare
    match = csv_songs == json_songs
    print(f"User {user_id} (encoded {user_idx}): "
        f"{'OK' if match else 'MISMATCH'}")

    if not match:
        mm = first_mismatch_with_context(csv_songs, json_songs, window=10)
        if mm:
            idx, csv_context, json_context = mm
            print(f"  First mismatch at index {idx}:")
            print(f"    CSV context : {csv_context}")
            print(f"    JSON context: {json_context}")
        print(f"  CSV length={len(csv_songs)}, JSON length={len(json_songs)}\n")


Checking order preservation for 20 random users...

User user_z4gXXNYv (encoded 13924): MISMATCH
  First mismatch at index 203:
    CSV context : [36963, 33688, 45464, 60178, 18463, 43264, 18463, 105599, 36364, 87603, 23975]
    JSON context: [36963, 33688, 45464, 60178, 18463, 18463, 43264, 105599, 36364, 87603, 23975]
  CSV length=386, JSON length=386

User user_evGC0u2F (encoded 9315): MISMATCH
  First mismatch at index 94:
    CSV context : [4478, 14231, 47004, 29618, 35792, 5440, 53766, 79581, 102157, 6173, 102157]
    JSON context: [4478, 14231, 47004, 29618, 35792, 53766, 5440, 79581, 102157, 6173, 102157]
  CSV length=325, JSON length=325

User user_Go9F9fEh (encoded 3945): MISMATCH
  First mismatch at index 80:
    CSV context : [104867, 24814, 80232, 56284, 47671, 5516, 74695, 56284, 1920, 52264, 104867]
    JSON context: [104867, 24814, 80232, 56284, 47671, 74695, 5516, 56284, 1920, 52264, 104867]
  CSV length=356, JSON length=356

User user_amW83YPG (encoded 8396): OK
User 

In [None]:
print(df_listening_history[df_listening_history["user"] == 'user_99ag1aCc'].head(60))

                 user              song         timestamp
773053  user_99ag1aCc  lrL4RAKX6f8Tz2h9  2019-01-23 13:59
773054  user_99ag1aCc  gV9zey0BURekfwUz  2019-01-23 15:32
773055  user_99ag1aCc  TjyPWv8mZY4HyiHG  2019-01-23 15:36
773056  user_99ag1aCc  mqcCD4gTbcbp524p  2019-01-23 15:40
773057  user_99ag1aCc  QKtqOEayujT7uiN8  2019-01-23 15:46
773058  user_99ag1aCc  hiZOnY63Q67hWHMU  2019-01-23 18:27
773059  user_99ag1aCc  SF0M6cQhMuaeirB5  2019-01-23 18:31
773060  user_99ag1aCc  zDJheSwyuy4euMpX  2019-01-23 18:34
773061  user_99ag1aCc  bgaW0PAE9g3ndI3H  2019-01-23 20:17
773062  user_99ag1aCc  rfaYNZCE9bm2ERQI  2019-01-23 20:25
773063  user_99ag1aCc  X90FM1k6kIL08QcO  2019-01-23 20:29
773064  user_99ag1aCc  pzDReyjgfMr7mNTX  2019-01-23 20:33
773065  user_99ag1aCc  82gbxU2ARroqVp1r  2019-01-24 14:42
773066  user_99ag1aCc  BdpE84qRXtc9deX6  2019-01-24 14:47
773067  user_99ag1aCc  rwM0ld0uwxjMK9Fh  2019-01-24 14:51
773068  user_99ag1aCc  WHPR9OrFBOToYWXJ  2019-01-24 14:56
773069  user_9