In [1]:
import wandb
from datetime import date

In [2]:
# Start a Weights & Biases run; the note is stamped with today's date.
run_settings = {
    'project': 'ContrastiveClustering',
    'notes': f'experiment-{date.today()}',
    'group': 'cp',
    'name': 'exp1',
}
wandb.init(**run_settings)

[34m[1mwandb[0m: Currently logged in as: [33mdawon[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
# General purposes
import numpy as np
import pandas as pd
from dotmap import DotMap

import torch
import torch.nn as nn
import torch.optim as optim

# Tensor decomposition
import tensorly as tl
from tensorly.cp_tensor import cp_to_tensor
from tensorly import check_random_state

# Plot
import matplotlib.pyplot as plt
import seaborn as sns

In [4]:
import torch.nn as nn

In [5]:
# Set the device, random state, and tensorly backend
random_state = 1234
rng = check_random_state(random_state)
# Seed torch too, so any torch-side randomness is reproducible as well.
torch.manual_seed(random_state)

# Prefer the GPU index this experiment was written for, but fall back to the
# first GPU (or the CPU) so the notebook still runs on machines that do not
# have that many CUDA devices.
preferred_gpu = 2
if torch.cuda.is_available():
    device = f'cuda:{preferred_gpu}' if preferred_gpu < torch.cuda.device_count() else 'cuda:0'
else:
    device = 'cpu'
tl.set_backend('pytorch')

In [6]:
# Experiment configuration: dataset name and decomposition method.
config = DotMap({'data': 'ml', 'method': 'cp'})

### Step 0: Prepare the sparse tensor in COO format

In [7]:
# Load the pre-split index/value arrays (train, validation, test) for the
# dataset selected in `config.data`.
path = f'/home/dahn017/courses/cs235/data/{config.data}'

train_indices, train_vals = (
    np.load(f'{path}/train_indices.npy'),
    np.load(f'{path}/train_values.npy'),
)
valid_indices, valid_vals = (
    np.load(f'{path}/valid_indices.npy'),
    np.load(f'{path}/valid_values.npy'),
)
test_indices, test_vals = (
    np.load(f'{path}/test_indices.npy'),
    np.load(f'{path}/test_values.npy'),
)

In [8]:
# Turn each split into COO components (only the nonzeros are kept):
# an (nnz x nmode) int64 index tensor and a flat (nnz,) float32 value tensor.
# torch.tensor(..., dtype=..., device=...) replaces the legacy
# LongTensor/FloatTensor constructors and casts the numpy arrays explicitly.
train_i = torch.tensor(train_indices, dtype=torch.long, device=device)
train_v = torch.tensor(train_vals, dtype=torch.float32, device=device).reshape(-1)

valid_i = torch.tensor(valid_indices, dtype=torch.long, device=device)
valid_v = torch.tensor(valid_vals, dtype=torch.float32, device=device).reshape(-1)

test_i = torch.tensor(test_indices, dtype=torch.long, device=device)
test_v = torch.tensor(test_vals, dtype=torch.float32, device=device).reshape(-1)

# Size the tensor from the largest index seen in ANY split. Inferring it from
# the training indices alone would leave the factor matrices too small whenever
# a validation/test entry uses a larger index, making the later lookups go out
# of bounds.
all_indices = torch.cat([train_i, valid_i, test_i], dim=0)
tensor_shape = tuple((all_indices.max(dim=0).values + 1).tolist())

# torch.sparse_coo_tensor expects indices as (nmode x nnz), hence the transpose.
stensor = torch.sparse_coo_tensor(train_i.t(), train_v, size=tensor_shape).coalesce()

### Step 1: Decompose the tensor (CPD)

In [9]:
def krprod(indices, factors):
    '''Evaluate a CP model at the given entries (sampled Khatri-Rao product).

    For every row of `indices`, the matching row of each factor matrix is
    gathered, the rows are multiplied element-wise across modes, and the
    result is summed over the rank dimension.

    Args:
        indices: integer tensor of shape (nnz, nmode); column m indexes into
            `factors[m]`.
        factors: sequence of nmode factor matrices, each of shape (dim_m, rank).

    Returns:
        1-D tensor of length nnz holding the reconstructed value of each entry.

    Raises:
        ValueError: if the number of index columns differs from the number of
            factor matrices.
    '''
    rank = factors[0].shape[-1]  # dim x rank
    nnz, n_index_modes = indices.shape  # nnz x nmode

    # zip() below would otherwise silently ignore the surplus modes.
    if n_index_modes != len(factors):
        raise ValueError(
            f'indices has {n_index_modes} modes but {len(factors)} factors were given'
        )

    # Allocate directly on the factors' device rather than relying on a
    # module-level `device` variable.
    sampled_kr = torch.ones((nnz, rank), device=factors[0].device)  # nnz x rank
    for idx, factor in zip(indices.t(), factors):  # nnz indices for each mode
        sampled_kr = sampled_kr * factor[idx]

    return sampled_kr.sum(1)  # one value per nonzero

In [10]:
# Hyper-parameter setting
rank, n_iter = 30, 10000     # columns per factor matrix / training-loop bound
lr, penalty = 1e-3, 1e-2     # Adam learning rate / weight on the factor norms
clusterk = 10                # presumably the number of clusters; not used in the cells shown here

# Record the same values in the experiment config.
config.update(rank=rank, n_iter=n_iter, lr=lr, penalty=penalty, clusterk=clusterk)

In [20]:
# Initialize one random factor matrix per tensor mode
nmodes = stensor.size()


def _random_factor(dim):
    """Return a trainable (dim x rank) matrix sampled from `rng` on `device`."""
    return tl.tensor(rng.random_sample((dim, rank)), device=device, requires_grad=True)


factors = list(map(_random_factor, nmodes))
print(nmodes)

torch.Size([610, 9724, 4110])


In [21]:
# Adam over the factor matrices, which are the parameters learned for the CPD
opt = optim.Adam(params=factors, lr=lr)

In [22]:
# Number of modes, derived from the factor list (one factor matrix per mode)
# instead of being hard-coded.
nmode = len(factors)

In [26]:
# Learn global information via CPD
# Early stopping: validation error is checked every 10 iterations and training
# halts as soon as it is higher than at the previous check.
old_val_error = float('inf')
# These norms do not change during training, so compute them once.
train_norm = tl.norm(train_v, 2)
valid_norm = tl.norm(valid_v, 2)
for i in range(1, n_iter + 1):  # run exactly n_iter updates
    opt.zero_grad()
    # Tensor rec loss
    rec = krprod(train_i, factors)
    loss = tl.norm(rec - train_v, 2)
    # Regularization: penalize the norm of every factor matrix
    for f in factors:
        loss = loss + penalty * tl.norm(f, 2)

    loss.backward()
    opt.step()
    if i % 10 == 0:
        with torch.no_grad():
            # `rec` was computed before this iteration's update, so the training
            # error reported here is the pre-step error.
            rec_error = (tl.norm(rec.detach() - train_v, 2) / train_norm).item()
            val_rec = krprod(valid_i, factors)
            val_error = (tl.norm(val_rec - valid_v, 2) / valid_norm).item()
        # wandb.log({'train_rec_error': rec_error, 'val_rec_error': val_error})
        print(f"Iters {i}, Rec. error: {rec_error:.4f} Valid Rec. error: {val_error:.4f}")
        if val_error > old_val_error:
            break
        old_val_error = val_error

Iters 10, Rec. error: 0.1986 Valid Rec. error: 0.2626
Iters 20, Rec. error: 0.1952 Valid Rec. error: 0.2630


In [25]:
# Evaluate the learned factors on the held-out test entries: normalized
# reconstruction error (NRE) and root-mean-squared error (RMSE).
with torch.no_grad():
    test_rec = krprod(test_i, factors)
    test_residual = test_rec.data - test_v
    test_nre = tl.norm(test_residual, 2) / tl.norm(test_v, 2)
    test_rmse = tl.sqrt(tl.mean(test_residual ** 2))
print(f"Test NRE : {test_nre:.4f} Test RMSE: {test_rmse:.4f}")

# wandb.log({'test_nre': test_nre, 'test_rmse': test_rmse})
# wandb.finish()

Test NRE : 0.2615 Test RMSE: 0.9550
