In [46]:
import os
import sys
import math
import pickle

import numpy as np
import pandas as pd
from typing import Tuple
import data_helper
import importlib
from einops import rearrange
from collections import defaultdict
from itertools import product
from tqdm.notebook import tqdm

from loguru import logger
logger.remove()
logger.add(sys.stdout, colorize=True, format="<blue>{time}</blue> <level>{message}</level>")
logger.level("INFO", color="<red><bold>")

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import init
from torchviz import make_dot
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [47]:
class SkipGramModel(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.target_embeddings = nn.Embedding(num_embeddings, embedding_dim)
        self.output = nn.Linear(embedding_dim, num_embeddings)

        initrange = 1.0 / self.embedding_dim
        init.uniform_(self.target_embeddings.weight.data, -
                      initrange, initrange)

    def forward(self, target, context):
        emb_target = self.target_embeddings(target)
        # print('current embedding: ', testskipgrammodel.target_embeddings.weight)

        score = self.output(emb_target)
        # print('score: ', score)
        score = F.log_softmax(score, dim=-1)
        # print('softmax score: ', score)

        losses = torch.stack([F.nll_loss(score, context_word)
                              for context_word in context.transpose(0, 1)])
        return losses.mean()

In [48]:
class Block2VecDataset(Dataset):
    def __init__(self, neighbor_radius = 1, block_ids_table_path='block_ids_alt.tsv'):
        super().__init__()
        self.neighbor_radius = neighbor_radius
        self.block_ids_table_path = block_ids_table_path
        
        self._gen_block_id_lookup_dict()
        self._read_blocks()
        
        padding = 2 * self.neighbor_radius  # one token on each side
        self.x_dim = self.x_lims[1] - self.x_lims[0] + 1 - padding
        self.y_dim = self.y_lims[1] - self.y_lims[0] + 1 - padding
        self.z_dim = self.z_lims[1] - self.z_lims[0] + 1 - padding

        # assuming 32*32*32
        # print(self._idx_to_coords(0)) # lower bound
        # print(self._idx_to_coords(1799)) # should be 2, 30, 30
        # print(self._idx_to_coords(26999)) # upper bound
        
        # print(self._get_neighbors(23,5,5, neighbor_radius=self.neighbor_radius))
        
        # print(self._getitem(3))
        # print(self._getitem(35))
        # print(self._getitem(355))
    
    def _read_size(self, neighbor_radius=1):
        return [neighbor_radius, self.world.shape[0] - neighbor_radius - 1], [neighbor_radius, self.world.shape[1] - neighbor_radius - 1], [neighbor_radius, self.world.shape[2] - neighbor_radius - 1]
    
    def _gen_block_id_lookup_dict(self):
        mc_block_database = pd.read_csv('block_ids_alt.tsv', sep='\t')
        mc_block_database = mc_block_database.filter(items=['numerical id', 'item id'])
        mc_block_database = mc_block_database.dropna(subset=["numerical id"])
        mc_block_database
        self.block_id_lookup_dict = mc_block_database.set_index('numerical id').to_dict()['item id']
    
    def _read_blocks(self):
        self.world = data_helper.all_trainx_as_df()[:1] # only take first 100 for now
        self.world = self.world['world']
        self.world = np.stack(self.world, axis=0)
        self.world = rearrange(self.world, 'n x y z b -> (n x) y z b')
        self.world = self.world[:,:,:,0]
        logger.info(f"Loaded in world with shape: {self.world.shape}")
        
        self.x_lims, self.y_lims, self.z_lims = self._read_size(neighbor_radius=0)
        
        self.block_frequency = defaultdict(int)
        coordinates_to_track = self._gen_coords(*(self._read_size(neighbor_radius=0)))
        logger.info("Collecting {} blocks for frequency calculation", len(coordinates_to_track))
        for coord in tqdm(coordinates_to_track):
            numerical_id = self._get_block(coord[0], coord[1], coord[2])
            # treating all meta of same id the same for simplicity
                
            item_id = self.block_id_lookup_dict[str(numerical_id)]
                
            self.block_frequency[item_id] += 1

        logger.info("Found {len(self.block_frequency)} unique blocks")
        self.block2idx = dict()
        self.idx2block = dict()
        for name, count in self.block_frequency.items():
            block_idx = len(self.block2idx)
            self.block2idx[name] = block_idx
            self.idx2block[block_idx] = name
        logger.info("idx2block and block2idx dictionaries generated")
        # print(self.block2idx)
        # print(self.idx2block)
        
    def _get_block(self, x, y, z):
        # returns the id for the block
        return self.world[x][y][z]

    def _get_neighbors(self, x, y, z, neighbor_radius=1):
        neighbor_coords = [(x + x_diff, y + y_diff, z + z_diff) for x_diff, y_diff, z_diff in product(list(range(-neighbor_radius, neighbor_radius + 1)), repeat=3) if x_diff != 0 or y_diff != 0 or z_diff != 0]
        return [self._get_block(*coord) for coord in neighbor_coords]

    def _gen_coords(self, x_lims, y_lims, z_lims):
        return [(x, y, z) for x, y, z in product(range(x_lims[0], x_lims[1] + 1), range(y_lims[0], y_lims[1] + 1), range(z_lims[0], z_lims[1] + 1))]
    
    def _idx_to_coords(self, index):
        z = index % (self.z_dim)
        y = int(((index - z) / (self.z_dim)) % (self.y_dim))
        x = int(((index - z) / (self.z_dim) - y) / (self.y_dim))
        x += self.x_lims[0] + self.neighbor_radius
        y += self.y_lims[0] + self.neighbor_radius
        z += self.z_lims[0] + self.neighbor_radius
        return x, y, z
    
    def __len__(self):
        return self.x_dim * self.y_dim * self.z_dim
    
    def _getitem(self, index):
        coords = self._idx_to_coords(index)
        numerical_id = self._get_block(*coords)
        item_id = self.block_id_lookup_dict[str(numerical_id)]
        target = self.block2idx[item_id]
        target = torch.tensor(int(target))
        
        neighbors = self._get_neighbors(*coords)
        item_ids = [self.block_id_lookup_dict[str(numerical_id)] for numerical_id in neighbors]
        context = [self.block2idx[item_id] for item_id in item_ids]
        
        context = torch.tensor(context)
        return target, context
    
    def __getitem__(self, index):
        return self._getitem(index)

In [49]:
class Block2Vec(pl.LightningModule):
    def __init__(self, embedding_dim = 32, initial_lr = 1e-3, neighbor_radius = 1, batch_size = 256, num_epochs = 30):
        super().__init__()
        self.save_hyperparameters() # making lightning save params under self.hparams
        
        self.dataset = Block2VecDataset(neighbor_radius)
        
        self.embedding_dim = embedding_dim
        self.learning_rate = initial_lr # initial learning rate
        self.neighbor_radius = neighbor_radius
        self.batch_size = batch_size
        self.num_epochs = num_epochs
        
        self.num_embeddings = len(self.dataset.block2idx)
        self.model = SkipGramModel(self.num_embeddings, self.embedding_dim)
        self.textures = dict()
        
        print(self.model)
        
    def forward(self, target, context) -> torch.Tensor:
        return self.model(target, context)

    def training_step(self, batch):
        loss = self.forward(*batch)
        self.log("loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.model.parameters(), lr=self.learning_rate)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            math.ceil(len(self.dataset) / self.batch_size) *
            self.num_epochs,
        )
        return [optimizer], [scheduler]
    
    def train_dataloader(self):
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=True,
            pin_memory=True,
        )
        
    def on_epoch_end(self):
        print("yay!")

In [52]:
def main():
    block2vec = Block2Vec()
    trainer = pl.Trainer(gpus=0, max_epochs=10, fast_dev_run=False)
    trainer.fit(block2vec)

In [53]:
main()

  0%|          | 0/1977 [00:00<?, ?it/s]

loaded 1977 houses
[34m2022-02-13T00:39:16.526929+0800[0m [31m[1mLoaded in world with shape: (32, 32, 32)[0m
[34m2022-02-13T00:39:16.530448+0800[0m [31m[1mCollecting 32768 blocks for frequency calculation[0m


  0%|          | 0/32768 [00:00<?, ?it/s]

[34m2022-02-13T00:39:16.590568+0800[0m [31m[1mFound {len(self.block_frequency)} unique blocks[0m
[34m2022-02-13T00:39:16.591022+0800[0m [31m[1midx2block and block2idx dictionaries generated[0m


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name  | Type          | Params
----------------------------------------
0 | model | SkipGramModel | 260   
----------------------------------------
260       Trainable params
0         Non-trainable params
260       Total params
0.001     Total estimated model params size (MB)
  rank_zero_warn(


SkipGramModel(
  (target_embeddings): Embedding(4, 32)
  (output): Linear(in_features=32, out_features=4, bias=True)
)
Epoch 0: 100%|██████████| 106/106 [00:02<00:00, 42.82it/s, loss=0.735, v_num=12]yay!
Epoch 1: 100%|██████████| 106/106 [00:01<00:00, 61.80it/s, loss=0.161, v_num=12]yay!
Epoch 2: 100%|██████████| 106/106 [00:02<00:00, 48.93it/s, loss=0.107, v_num=12]yay!
Epoch 3: 100%|██████████| 106/106 [00:02<00:00, 48.95it/s, loss=0.104, v_num=12]yay!
Epoch 4: 100%|██████████| 106/106 [00:02<00:00, 49.62it/s, loss=0.0956, v_num=12]yay!
Epoch 5: 100%|██████████| 106/106 [00:02<00:00, 42.19it/s, loss=0.0938, v_num=12]yay!
Epoch 6: 100%|██████████| 106/106 [00:02<00:00, 43.55it/s, loss=0.0901, v_num=12]yay!
Epoch 7: 100%|██████████| 106/106 [00:02<00:00, 42.64it/s, loss=0.0888, v_num=12]yay!
Epoch 8:   1%|          | 1/106 [00:00<00:03, 28.13it/s, loss=0.0891, v_num=12]  

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
