In [1]:
from typing import Optional, Callable, List, Dict, Tuple
import os

import dgmc
import pandas as pd
from sklearn.preprocessing import OneHotEncoder, LabelEncoder

import torch
from torch import Tensor
import funcs
import pytorch_lightning as pl
from torch.nn import Linear
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
import torch_geometric


%reload_ext autoreload
%autoreload 2

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
# Joined Sizmek/Zync sample. Set ZETA_DATA_PATH to read a different copy without editing the cell.
file_path = os.environ.get("ZETA_DATA_PATH", "s3://drose-sandbox/sizmek_zync_1m")
df = pd.read_parquet(file_path)
print(f"Shape: {df.shape[0]:,} Memory: {df.memory_usage(deep=True).sum()/1e9:.2f}GB")

Shape: 1,000,000 Memory: 2.72GB


In [3]:
# Split the joined frame into the two sources; each one becomes a graph in ZetaDataset.
sizmek_cols = [
    "account_id", "url", "referrer_url",
    "city_code", "state_code", "dma_code", "country_code",
]
zync_cols = [
    "session_id", "referrer", "client",
    "user_agent_platform", "user_agent_language", "user_agent_browser",
]

sizmek_df, zync_df = df[sizmek_cols], df[zync_cols]

In [4]:
# Peek at the first few Sizmek rows.
sizmek_df.head(n=3)

Unnamed: 0,account_id,url,referrer_url,city_code,state_code,dma_code,country_code
0,19967,windstream.net,,5072006,NE,722,US
1,19967,windstream.net,,5072006,NE,722,US
2,35927,https://www.windstream.net/?inc=1176,https://www.windstream.net/?inc=1175,5072006,NE,722,US


In [5]:
# Peek at the first few Zync rows.
zync_df.head(n=3)

Unnamed: 0,session_id,referrer,client,user_agent_platform,user_agent_language,user_agent_browser
0,31b423df-5602-4ccc-8983-6c7ab6f65e99:162431114...,https://www.windstream.net/?inc=532,sizmek,windows,,chrome
1,31b423df-5602-4ccc-8983-6c7ab6f65e99:162431114...,https://www.windstream.net/?inc=532,sizmek,windows,,chrome
2,31b423df-5602-4ccc-8983-6c7ab6f65e99:162431114...,https://www.windstream.net/?inc=532,sizmek,windows,,chrome


In [23]:
class ZetaDataset(torch_geometric.data.InMemoryDataset):
    """Sizmek / Zync graph pair packaged as a single PyG ``Data`` object for DGMC.

    Each DataFrame becomes one graph. Node features are a one-hot encoding of
    ``feature_cols`` (one row per DataFrame row, in row order). Edges come from
    ``funcs.connect_edges`` on the matching entry of ``column``.

    The resulting ``Data`` holds:

    - ``x1`` / ``edge_index1`` for the Sizmek graph;
    - ``x2`` / ``edge_index2`` for the Zync graph;
    - ``train_y`` / ``test_y`` as ``(2, N)`` correspondence tensors.

    The processed ``.pt`` file is deleted right after it is loaded. Every construction
    therefore re-processes the DataFrames passed in instead of reusing a stale cache.

    Args:
        root: Dataset root directory.
        sizmek_df: Rows for the first (source) graph.
        zync_df: Rows for the second (target) graph.
        column: ``[sizmek_column, zync_column]``, the columns used to connect edges.
        label: Currently unused; correspondences come from ``process_y``.
        feature_cols: ``{"sizmek": [...], "zync": [...]}``, the columns to one-hot encode.
        parse_url: If True, ``url`` / ``referrer`` edge columns are reduced to their netloc.
        expand_x: Width to zero-pad ``x2`` to. Defaults to the width of ``x1`` so both
            graphs have the same feature dimension.
        transform: Standard PyG transform hook.
        pre_transform: Standard PyG pre-transform hook.
    """

    def __init__(self, root: str, sizmek_df: pd.DataFrame, zync_df: pd.DataFrame, column: List[str],
                 label: List[str], feature_cols: Optional[Dict[str, List[str]]] = None,
                 parse_url: bool = False, expand_x: Optional[int] = None,
                 transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None
        ):
        self.root = root
        self.sizmek_df = sizmek_df
        self.zync_df = zync_df
        self.column = column
        self.label = label  # NOTE(review): not read anywhere yet
        self.feature_cols = feature_cols
        self.parse_url = parse_url
        self.expand_x = expand_x
        # PyG's constructor runs process() when processed_paths[0] is missing.
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])
        # Drop the cache so the next construction re-processes the (possibly different)
        # DataFrames rather than silently loading results built from old inputs.
        print("Removing processed file. . .")
        os.remove(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        return [
            "sizmek_bidstream_raw_20210625_10k.csv",
            "zync_session_tracking_orc_20210625_10k.csv"
        ]

    @property
    def processed_file_names(self) -> List[str]:
        return ["ZetaDataset.pt"]

    def download(self):
        # Data is supplied in-memory via the constructor; nothing to download.
        pass

    def process(self):
        """Build both graphs plus placeholder correspondences and save the collated result."""
        if self.feature_cols is None:
            raise ValueError("feature_cols must provide 'sizmek' and 'zync' column lists")

        x1, edge_index1 = self.process_graph(self.sizmek_df, self.column[0], self.feature_cols["sizmek"])
        # The encoder is sized from x1's width and consumes both x1 and x2, so pad x2 to
        # x1's width unless an explicit expand_x was requested.
        target_dim = self.expand_x if self.expand_x is not None else x1.size(1)
        x2, edge_index2 = self.process_graph(self.zync_df, self.column[1], self.feature_cols["zync"], target_dim)

        # TODO(review): train_y and test_y are identical identity pairs (and `label` is unused),
        # so Hits@k computed on test_y measures the training correspondences.
        train_y = self.process_y()
        test_y = self.process_y()

        data = Data(x1=x1, edge_index1=edge_index1, x2=x2,
                    edge_index2=edge_index2, train_y=train_y,
                    test_y=test_y)
        torch.save(self.collate([data]), self.processed_paths[0])

    def process_graph(self, df: pd.DataFrame, column: str, feature_cols: List[str],
                      expand_x: Optional[int] = None) -> Tuple[Tensor, Tensor]:
        """One-hot encode ``feature_cols`` as node features and connect edges on ``column``.

        Args:
            df: Source rows; the caller's frame is never modified.
            column: Column handed to ``funcs.connect_edges``.
            feature_cols: Columns to one-hot encode.
            expand_x: If given, zero-pad the feature matrix on the right to this many columns.
                A matrix that is already at least this wide is left unchanged.

        Returns:
            ``(x, edge_index)``:

            - ``x`` is a float tensor with one row per row of ``df``;
            - ``edge_index`` is a ``(2, E)`` long tensor built from the ``source`` / ``target``
              columns returned by ``funcs.connect_edges``.
        """
        print(f"Processing graph on {column}")

        if self.parse_url and column in ["url", "referrer"]:
            from urllib.parse import urlparse
            # Copy so the caller's DataFrame is not mutated (also avoids SettingWithCopyWarning).
            df = df.copy()
            df[column] = df[column].apply(
                lambda u: urlparse(u).netloc if pd.notnull(u) else u
            )

        encoder = OneHotEncoder(handle_unknown="ignore")
        x = torch.tensor(encoder.fit_transform(df[feature_cols]).toarray(), dtype=torch.float)

        if expand_x is not None:
            print(f"Expanding X to {expand_x}")
            # Pad in a single concat. Inserting DataFrame columns one at a time
            # emitted a pandas warning.
            pad = expand_x - x.size(1)
            if pad > 0:
                x = torch.cat([x, x.new_zeros(x.size(0), pad)], dim=1)

        edges = funcs.connect_edges(df, column)
        edge_index = torch.tensor(
            edges[['source', 'target']].T.values, dtype=torch.long
        )

        return x, edge_index

    def process_y(self, num_pairs: int = 2000) -> Tensor:
        """Placeholder correspondences pairing node ``i`` of graph 1 with node ``i`` of graph 2.

        Args:
            num_pairs: Number of leading node indices to pair.

        Returns:
            A ``(2, num_pairs)`` long tensor whose two rows are both ``0 .. num_pairs - 1``.
            This is presumably only meaningful when ``sizmek_df`` and ``zync_df`` are
            row-aligned; verify at the call site.
        """
        idx = torch.arange(num_pairs)
        return torch.stack([idx, idx], dim=0)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(sizmek/zync)"

In [24]:
class SumEmbedding:
    """PyG-style transform that replaces ``x1`` and ``x2`` with their sums over dim 1.

    The input object is modified in place and also returned.

    NOTE(review): for ZetaDataset's one-hot features, every row has one active entry per
    encoded column. Every node would then sum to the same value, so confirm this
    collapse is intended. Transforms are applied on dataset indexing, so code reading
    ``dataset.data`` directly presumably bypasses it.
    """

    def __call__(self, data):
        for attr in ("x1", "x2"):
            setattr(data, attr, getattr(data, attr).sum(dim=1))
        return data

In [25]:
# Reproducible sample of rows. zync_small takes the *same records* by index label,
# so row i of both frames describes the same event. The dataset's identity
# correspondences (node i <-> node i) depend on this alignment. `.iloc` would treat the
# labels as positions, which is only correct while the index happens to be a RangeIndex.
SAMPLE_SIZE = 5_000
sizmek_small = sizmek_df.sample(n=SAMPLE_SIZE, random_state=0)
zync_small = zync_df.loc[sizmek_small.index]

In [26]:
# Columns one-hot encoded as node features, per source. The edge-defining columns
# (url / referrer) are passed separately through `column`.
feature_cols = {
    "sizmek": [
        "account_id", "referrer_url", "city_code",
        "state_code", "dma_code", "country_code",
    ],
    "zync": [
        "client", "user_agent_platform",
        "user_agent_language", "user_agent_browser",
    ],
}

zeta_data = ZetaDataset(
    root="./data/",
    sizmek_df=sizmek_small,
    zync_df=zync_small,
    column=["url", "referrer"],        # [sizmek edge column, zync edge column]
    label=["zeta_user_id", "client_id"],
    feature_cols=feature_cols,
    parse_url=False,
    transform=SumEmbedding(),
)
zeta_data

Processing...
Processing graph on url
Create dict for url
Processing graph on referrer
Expanding X to 976


  features[col] = 0


Create dict for referrer
Done!
Removing processed file. . .


ZetaDataset(sizmek/zync)

In [27]:
import argparse

# Hyper-parameters are declared through argparse so the same code can run as a script.
# In the notebook an empty argument list is parsed to take the defaults. A plain string
# would be split into single-character arguments.
parser = argparse.ArgumentParser()
parser.add_argument('--dim', type=int, default=256)
parser.add_argument('--rnd_dim', type=int, default=32)
parser.add_argument('--num_layers', type=int, default=3)
parser.add_argument('--num_steps', type=int, default=10)
parser.add_argument('--k', type=int, default=10)
args = parser.parse_args([])


# psi_1: node-feature encoder. Its input width is taken from x1, and ZetaDataset pads x2
# to the same width.
psi_1 = dgmc.models.RelCNN(zeta_data.data.x1.size(-1), args.dim, args.num_layers, batch_norm=False,
               cat=True, lin=True, dropout=0.5)
# psi_2: rnd_dim channels in and out; presumably consumes DGMC's random node indicators.
psi_2 = dgmc.models.RelCNN(args.rnd_dim, args.rnd_dim, args.num_layers, batch_norm=False,
               cat=True, lin=True, dropout=0.0)

psi_1

RelCNN(976, 256, num_layers=3, batch_norm=False, cat=True, lin=True, dropout=0.5)

In [30]:
# num_steps is set per phase by the training cells (0 = initial matching only).
model = dgmc.models.DGMC(psi_1, psi_2, num_steps=None, k=args.k).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Use the raw collated Data (not zeta_data[0]) so the SumEmbedding transform is not applied:
# psi_1 was sized from the un-summed feature width. Move it to the model's device, otherwise
# every forward pass mixes CPU and CUDA tensors when a GPU is available.
data = zeta_data.data.to(device)


def train():
    """Run one optimisation step over the full graph pair.

    Returns:
        The loss as a detached 0-dim tensor. Detaching drops the autograd graph reference
        without forcing a device sync, and the tensor still formats with ``:.4f``.
    """
    model.train()
    optimizer.zero_grad()

    # train_y is passed only during training; test() calls the model without it.
    _, S_L = model(data.x1, data.edge_index1, None, None, data.x2,
                   data.edge_index2, None, None, data.train_y)

    loss = model.loss(S_L, data.train_y)
    loss.backward()
    optimizer.step()
    return loss.detach()


@torch.no_grad()
def test():
    """Evaluate the current model against ``data.test_y``.

    Returns:
        ``(hits1, hits10)`` computed by ``model.acc`` and ``model.hits_at_k(10, ...)``.
    """
    model.eval()

    _, S_L = model(data.x1, data.edge_index1, None, None, data.x2,
                   data.edge_index2, None, None)

    hits1 = model.acc(S_L, data.test_y)
    hits10 = model.hits_at_k(10, S_L, data.test_y)

    return hits1, hits10

In [None]:
# Profile early training steps and write a TensorBoard trace.
# Schedule: skip 2 steps, warm up for 2, record 6, then stop recording (repeat=1).
# The loop keeps training afterwards.
model.num_steps = 0
profile_schedule = torch.profiler.schedule(wait=2, warmup=2, active=6, repeat=1)
trace_handler = torch.profiler.tensorboard_trace_handler("./logs/gnn_v5")

with torch.profiler.profile(schedule=profile_schedule, on_trace_ready=trace_handler,
                            with_stack=True) as profiler:
    for epoch in range(1, 500):
        print("step:{}".format(epoch))
        loss = train()
        profiler.step()

        # Evaluate every 5th epoch only.
        if epoch % 5 != 0:
            continue
        hits1, hits10 = test()
        print(f"{epoch:03d}: Loss: {loss:.4f}, Hits@1: {hits1:.4f}, Hits@10: {hits10:.4f}")

step:1
S: torch.Size([5000, 976]), torch.Size([2, 1118134])
T: torch.Size([5000, 976]), torch.Size([2, 431140])
step:2
S: torch.Size([5000, 976]), torch.Size([2, 1118134])
T: torch.Size([5000, 976]), torch.Size([2, 431140])
step:3
S: torch.Size([5000, 976]), torch.Size([2, 1118134])
T: torch.Size([5000, 976]), torch.Size([2, 431140])
step:4
S: torch.Size([5000, 976]), torch.Size([2, 1118134])
T: torch.Size([5000, 976]), torch.Size([2, 431140])
step:5
S: torch.Size([5000, 976]), torch.Size([2, 1118134])
T: torch.Size([5000, 976]), torch.Size([2, 431140])
S: torch.Size([5000, 976]), torch.Size([2, 1118134])
T: torch.Size([5000, 976]), torch.Size([2, 431140])
005: Loss: 2.8638, Hits@1: 0.0035, Hits@10: 0.0420
step:6
S: torch.Size([5000, 976]), torch.Size([2, 1118134])
T: torch.Size([5000, 976]), torch.Size([2, 431140])
step:7
S: torch.Size([5000, 976]), torch.Size([2, 1118134])
T: torch.Size([5000, 976]), torch.Size([2, 431140])
step:8
S: torch.Size([5000, 976]), torch.Size([2, 1118134])


In [29]:
# Phase 1 (epochs before REFINE_EPOCH): initial feature matching only (num_steps = 0).
# Phase 2 (from REFINE_EPOCH on): run args.num_steps refinement steps with model.detach = True.
NUM_EPOCHS = 5
REFINE_EPOCH = 3

print('Optimize initial feature matching...')
model.num_steps = 0
for epoch in range(1, NUM_EPOCHS + 1):
    if epoch == REFINE_EPOCH:
        print('Refine correspondence matrix...')
        model.num_steps = args.num_steps
        model.detach = True

    loss = train()

    # Evaluate after every epoch; the run is short enough for this to be affordable.
    hits1, hits10 = test()
    print(f'{epoch:03d}: Loss: {loss:.4f}, Hits@1: {hits1:.4f}, Hits@10: {hits10:.4f}')

Optimize initial feature matching...
S: torch.Size([5000, 976]), torch.Size([2, 1118134])
T: torch.Size([5000, 976]), torch.Size([2, 431140])
S: torch.Size([5000, 976]), torch.Size([2, 1118134])
T: torch.Size([5000, 976]), torch.Size([2, 431140])
001: Loss: 3.0431, Hits@1: 0.0010, Hits@10: 0.0095
S: torch.Size([5000, 976]), torch.Size([2, 1118134])
T: torch.Size([5000, 976]), torch.Size([2, 431140])
S: torch.Size([5000, 976]), torch.Size([2, 1118134])
T: torch.Size([5000, 976]), torch.Size([2, 431140])
002: Loss: 3.0252, Hits@1: 0.0025, Hits@10: 0.0170
Refine correspondence matrix...
S: torch.Size([5000, 976]), torch.Size([2, 1118134])
T: torch.Size([5000, 976]), torch.Size([2, 431140])
S: torch.Size([5000, 976]), torch.Size([2, 1118134])
T: torch.Size([5000, 976]), torch.Size([2, 431140])
003: Loss: 3.0741, Hits@1: 0.0020, Hits@10: 0.0235
S: torch.Size([5000, 976]), torch.Size([2, 1118134])
T: torch.Size([5000, 976]), torch.Size([2, 431140])
S: torch.Size([5000, 976]), torch.Size([2, 