# DEXML - Dual-encoders for Extreme Multi-label Classification

## Download and install requirements (for Colab)

In [None]:
# Fetch the DEXML code (cloned into ./dexml); the nets/datasets/utils modules imported below come from here.
!git clone https://github.com/nilesh2797/dexml

Cloning into 'dexml'...
remote: Enumerating objects: 58, done.[K
remote: Counting objects: 100% (58/58), done.[K
remote: Compressing objects: 100% (37/37), done.[K
remote: Total 58 (delta 16), reused 51 (delta 13), pack-reused 0[K
Receiving objects: 100% (58/58), 47.27 KiB | 820.00 KiB/s, done.
Resolving deltas: 100% (16/16), done.


In [None]:
# %cd (unlike `!cd`) changes the kernel's working directory, so later relative paths
# (configs/, utils/, Datasets/) resolve against the repository root.
%cd dexml

/content/dexml


In [None]:
!pip install -r requirements.txt

## Imports

In [1]:
import sys, os, time, socket, yaml, wandb, logging, numpy as np
import logging.config
import scipy.sparse as sp
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
import matplotlib.pyplot as plt

In [2]:
from nets import TFEncoder
from datasets import DATA_MANAGERS
from utils.helper_utils import load_config_and_runtime_args, compute_xmc_metrics
from utils.nns_utils import ExactSearch

## Download and process datasets

In [3]:
import scipy.sparse as sp
def to_scipy_matrix(lol_inds, num_cols):
  """Build a binary CSR matrix from per-row lists of column indices.

  Args:
    lol_inds: sequence (a list, or an object ndarray such as a pandas column's `.values`)
      whose i-th element lists the column indices set in row i. Rows may be empty.
    num_cols: number of columns of the output matrix (e.g. the number of labels).

  Returns:
    scipy.sparse.csr_matrix of shape (len(lol_inds), num_cols) with 1.0 at every listed
    (row, col) entry. Repeated indices within a row are summed when converting to CSR.
  """
  row_lens = np.array([len(inds) for inds in lol_inds], dtype=np.int64)
  num_rows = len(row_lens)
  # Build explicit int64 index arrays: a bare np.concatenate raises on an empty input and
  # produces float64 "indices" when every row is empty.
  if num_rows:
    cols = np.concatenate([np.asarray(inds, dtype=np.int64) for inds in lol_inds])
  else:
    cols = np.zeros(0, dtype=np.int64)
  # Row i is repeated once per index it holds.
  rows = np.repeat(np.arange(num_rows, dtype=np.int64), row_lens)
  data = np.ones(len(cols), dtype=np.float64)
  return sp.coo_matrix((data, (rows, cols)), shape=(num_rows, num_cols)).tocsr()

import pandas, gzip
def _read_jsonl_gz(path):
  """Load a gzipped JSON-lines file into a DataFrame, closing the file handle afterwards."""
  with gzip.open(path) as fin:
    return pandas.read_json(fin, lines=True)

def _write_lines(path, texts):
  """Write one stripped text per line (UTF-8), so line i lines up with row i of the label matrices."""
  with open(path, 'w', encoding='utf-8') as fout:
    for text in texts:
      # An embedded line break would split one record over two lines and shift every later row.
      fout.write(text.replace('\r', ' ').replace('\n', ' ').strip() + '\n')

def process_lf_amazon_datasets(dataset, max_len=None):
  """Convert a raw LF-Amazon* download into the files used by the rest of this notebook.

  Reads Datasets/{dataset}/{trn,tst,lbl}.json.gz (JSON lines; `title` is used for all three,
  `target_ind` for trn/tst) and writes:
    - Datasets/{dataset}/Y.trn.npz, Y.tst.npz: sparse 0/1 matrices (points x labels)
    - Datasets/{dataset}/raw/{trn_X,tst_X,Y}.txt: one title per line
  then tokenizes each text file with utils/tokenization_utils.py.

  Args:
    dataset: directory name under Datasets/, e.g. 'LF-AmazonTitles-131K'.
    max_len: tokenizer max length; when None, 32 for *Titles* datasets and 128 otherwise.

  Raises:
    subprocess.CalledProcessError: if a tokenization subprocess exits with a non-zero status.
  """
  import subprocess  # local import: the notebook's import cell does not import it

  data_dir = f'Datasets/{dataset}'
  raw_dir = f'{data_dir}/raw'
  os.makedirs(raw_dir, exist_ok=True)

  print('Reading raw dataset files...')
  trn_df = _read_jsonl_gz(f'{data_dir}/trn.json.gz')
  tst_df = _read_jsonl_gz(f'{data_dir}/tst.json.gz')
  lbl_df = _read_jsonl_gz(f'{data_dir}/lbl.json.gz')

  print('Processing Y (label) files...')
  num_labels = lbl_df.shape[0]
  sp.save_npz(f'{data_dir}/Y.trn.npz', to_scipy_matrix(trn_df.target_ind.values, num_labels))
  sp.save_npz(f'{data_dir}/Y.tst.npz', to_scipy_matrix(tst_df.target_ind.values, num_labels))

  print('Processing X (input) files...')
  _write_lines(f'{raw_dir}/trn_X.txt', trn_df.title.values)
  _write_lines(f'{raw_dir}/tst_X.txt', tst_df.title.values)
  _write_lines(f'{raw_dir}/Y.txt', lbl_df.title.values)

  print('Tokenizing X (input) files...')
  if max_len is None:
    # Case-insensitive match: dataset names are spelled 'LF-AmazonTitles-...'.
    max_len = 32 if 'titles' in dataset.lower() else 128
  for name in ('trn_X', 'tst_X', 'Y'):
    # sys.executable runs the tokenizer with the kernel's own interpreter; check=True
    # surfaces a failed tokenization instead of silently leaving the .dat files missing.
    subprocess.run(
      [sys.executable, 'utils/tokenization_utils.py',
       '--data-path', f'{raw_dir}/{name}.txt', '--tf-max-len', str(max_len)],
      check=True)

### EURLex-4K

In [4]:
!mkdir -p Datasets
!cd Datasets; gdown 1A_sL_mzpkmnr6g0DSZ0_xJTr4GN-rIfi; tar -xvzf Eurlex-4K.tar.gz; mv Eurlex-4K EURLex-4K
!cd Datasets/EURLex-4K; mkdir -p raw; mv train_raw_texts.txt raw/trn_X.txt; mv test_raw_texts.txt raw/tst_X.txt; mv label_map.txt raw/Y.txt

Downloading...
From (original): https://drive.google.com/uc?id=1A_sL_mzpkmnr6g0DSZ0_xJTr4GN-rIfi
From (redirected): https://drive.google.com/uc?id=1A_sL_mzpkmnr6g0DSZ0_xJTr4GN-rIfi&confirm=t&uuid=791d63b2-38dd-483f-9fb5-09ec0a1faefe
To: /home/nilesh/work/DEXML/Datasets/Eurlex-4K.tar.gz
100%|████████████████████████████████████████| 157M/157M [00:01<00:00, 94.7MB/s]
./Eurlex-4K/
./Eurlex-4K/train_raw_texts.txt
./Eurlex-4K/X.tst.npz
./Eurlex-4K/X.tst.finetune.xlnet.npy
./Eurlex-4K/X.trn.npz
./Eurlex-4K/Y.trn.npz
./Eurlex-4K/Y.tst.npz
./Eurlex-4K/test_raw_texts.txt
./Eurlex-4K/X.trn.finetune.xlnet.npy
./Eurlex-4K/label_map.txt


In [5]:
!python utils/tokenization_utils.py --data-path Datasets/EURLex-4K/raw/trn_X.txt --tf-max-len 128
!python utils/tokenization_utils.py --data-path Datasets/EURLex-4K/raw/tst_X.txt --tf-max-len 128
!python utils/tokenization_utils.py --data-path Datasets/EURLex-4K/raw/Y.txt --tf-max-len 128

Read 15449 lines
Dumping tokenized file at Datasets/EURLex-4K/raw/trn_X.bert-base-uncased_128.dat...
100%|█████████████████████████████████████████████| 1/1 [00:14<00:00, 14.07s/it]
Finished tokenize_dump_memmap in 14.0789 secs
Read 3865 lines
Dumping tokenized file at Datasets/EURLex-4K/raw/tst_X.bert-base-uncased_128.dat...
100%|█████████████████████████████████████████████| 1/1 [00:03<00:00,  3.89s/it]
Finished tokenize_dump_memmap in 3.8926 secs
Read 3956 lines
Dumping tokenized file at Datasets/EURLex-4K/raw/Y.bert-base-uncased_128.dat...
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  4.59it/s]
Finished tokenize_dump_memmap in 0.2221 secs


### LF-AmazonTitles-131K

In [9]:
!mkdir -p Datasets
!cd Datasets; gdown 1WuquxCAg8D4lKr-eZXPv4nNw2S2lm7_E;
!cd Datasets; unzip LF-Amazon-131K.raw.zip; mv LF-Amazon-131K LF-AmazonTitles-131K; mkdir -p LF-AmazonTitles-131K/raw

Downloading...
From (original): https://drive.google.com/uc?id=1WuquxCAg8D4lKr-eZXPv4nNw2S2lm7_E
From (redirected): https://drive.google.com/uc?id=1WuquxCAg8D4lKr-eZXPv4nNw2S2lm7_E&confirm=t&uuid=a28892c4-78c3-4762-a294-caec34754b78
To: /home/nilesh/work/DEXML/Datasets/LF-Amazon-131K.raw.zip
100%|████████████████████████████████████████| 245M/245M [00:02<00:00, 94.6MB/s]
Archive:  LF-Amazon-131K.raw.zip
   creating: LF-Amazon-131K/
  inflating: LF-Amazon-131K/lbl.json.gz  
  inflating: LF-Amazon-131K/trn.json.gz  
  inflating: LF-Amazon-131K/filter_labels_test.txt  
  inflating: LF-Amazon-131K/tst.json.gz  
  inflating: LF-Amazon-131K/filter_labels_train.txt  


In [None]:
# Build label matrices, raw text files and tokenized inputs for the dataset downloaded above.
process_lf_amazon_datasets('LF-AmazonTitles-131K')

### LF-AmazonTitles-1.3M

In [21]:
!mkdir -p Datasets
!cd Datasets; gdown 12zH4mL2RX8iSvH0VCNnd3QxO4DzuHWnK;
!cd Datasets; unzip LF-Amazon-1.3M.raw.zip; mv LF-Amazon-1.3M LF-AmazonTitles-1.3M; mkdir -p LF-AmazonTitles-1.3M/raw

Downloading...
From (original): https://drive.google.com/uc?id=12zH4mL2RX8iSvH0VCNnd3QxO4DzuHWnK
From (redirected): https://drive.google.com/uc?id=12zH4mL2RX8iSvH0VCNnd3QxO4DzuHWnK&confirm=t&uuid=391cf505-02cf-4591-8a96-37f829f9328e
To: /home/nilesh/work/DEXML/Datasets/LF-Amazon-1.3M.raw.zip
100%|████████████████████████████████████████| 890M/890M [00:13<00:00, 67.8MB/s]
Archive:  LF-Amazon-1.3M.raw.zip
   creating: LF-Amazon-1.3M/
  inflating: LF-Amazon-1.3M/lbl.json.gz  
  inflating: LF-Amazon-1.3M/trn.json.gz  
  inflating: LF-Amazon-1.3M/filter_labels_test.txt  
  inflating: LF-Amazon-1.3M/tst.json.gz  
  inflating: LF-Amazon-1.3M/filter_labels_train.txt  


In [22]:
# Build label matrices, raw text files and tokenized inputs for LF-AmazonTitles-1.3M.
process_lf_amazon_datasets(dataset='LF-AmazonTitles-1.3M')

Read 2248619 lines
Dumping tokenized file at Datasets/LF-AmazonTitles-1.3M/raw/trn_X.bert-base-uncased_32.dat...


100%|██████████| 5/5 [00:41<00:00,  8.34s/it]


Finished tokenize_dump_memmap in 41.7673 secs
Read 970237 lines
Dumping tokenized file at Datasets/LF-AmazonTitles-1.3M/raw/tst_X.bert-base-uncased_32.dat...


100%|██████████| 2/2 [00:15<00:00,  7.57s/it]


Finished tokenize_dump_memmap in 15.1589 secs
Read 1305265 lines
Dumping tokenized file at Datasets/LF-AmazonTitles-1.3M/raw/Y.bert-base-uncased_32.dat...


100%|██████████| 3/3 [00:20<00:00,  6.89s/it]


Finished tokenize_dump_memmap in 20.7157 secs


## Run DEXML (evaluate a pretrained checkpoint)

In [15]:
# Dataset to evaluate. It selects the config file loaded in the next cell
# (configs/<dataset>/...), so it should be one of the datasets prepared above.
dataset = 'LF-AmazonTitles-1.3M'

In [16]:
# Load the dual-encoder config for the chosen dataset; the leading '' presumably stands
# in for argv[0] of a command-line invocation.
config_file = f'configs/{dataset}/dist-de-all_decoupled-softmax.yaml'
args = load_config_and_runtime_args(['', config_file])

# Point the run at the prepared dataset folder and a demo results folder (created if missing).
args.DATA_DIR, args.OUT_DIR = f'Datasets/{args.dataset}', f'Results/{args.dataset}/demo'
os.makedirs(args.OUT_DIR, exist_ok=True)

In [17]:
# Instantiate the data manager class named by args.data_manager and build its loaders.
# Only val_loader is used in the rest of this notebook.
data_manager = DATA_MANAGERS[args.data_manager](args)
trn_loader, val_loader, _ = data_manager.build_data_loaders()

neg_type: none


In [18]:
# Point the encoder at the released DEXML checkpoint for this dataset (presumably a
# Hugging Face Hub id; config.json and model.safetensors are fetched on first use).
args.tf = 'quicktensor/dexml_' + args.dataset.lower()
net = TFEncoder(args)

# Move the model to the GPU only when one is available.
if torch.cuda.is_available():
  net.to('cuda')

config.json:   0%|          | 0.00/546 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/265M [00:00<?, ?B/s]

In [30]:
# Embed the evaluation inputs and all label texts with the same encoder.
EMB_BSZ = 1024  # lower this if the embedding pass runs out of GPU memory
tst_embs = net.get_embs(val_loader.dataset.x_dataset, bsz=EMB_BSZ)
lbl_embs = net.get_embs(val_loader.dataset.y_dataset, bsz=EMB_BSZ)

In [28]:
# Exact nearest-label search over all label embeddings. K=100 presumably keeps the top-100
# labels per test point, which covers the R@100 metric reported below.
# Device follows the same CUDA-availability check as the model-loading cell, so this cell
# no longer assumes a GPU (assumes ExactSearch accepts 'cpu' — TODO confirm).
search_device = 'cuda' if torch.cuda.is_available() else 'cpu'
es = ExactSearch(lbl_embs, K=100, device=search_device)
score_mat = es.search(tst_embs)

Searching in 2 shards


Searching shard 1/2: 100%|██████████| 1895/1895 [02:19<00:00, 13.55it/s]
Searching shard 2/2: 100%|██████████| 1895/1895 [00:46<00:00, 40.36it/s]
Aggregating sharded results: 100%|██████████| 10/10 [00:01<00:00,  5.26it/s]


Total time, time per point : 196.68s, 0.2027 ms/pt


In [29]:
from utils.helper_utils import load_filter_mat, _filter

# Some datasets ship a filter_labels_test.txt (presumably test-point/label pairs excluded
# from evaluation); when present, remove those entries from score_mat.
# The path is built from args.DATA_DIR so it always matches the data folder being evaluated.
filter_path = f'{args.DATA_DIR}/filter_labels_test.txt'
if os.path.exists(filter_path):
    print('Filtering predictions...')
    filter_mat = load_filter_mat(filter_path, val_loader.dataset.labels.shape)
    # copy=False: score_mat itself is filtered rather than a copy of it.
    _filter(score_mat, filter_mat, copy=False)

Filtering predictions...


  self._set_arrayXarray(i, j, x)


In [30]:
# Compare the retrieved scores with the ground-truth labels (P@k, nDCG@k, MRR@10, R@k).
# inv_prop=None passes no inverse-propensity weights; the returned table is the cell output.
compute_xmc_metrics(score_mat, val_loader.dataset.labels, inv_prop=None, disp=False)

Unnamed: 0,P@1,P@3,P@5,nDCG@1,nDCG@3,nDCG@5,MRR@10,R@10,R@50,R@100
Method,58.42,50.83,45.48,58.42,55.81,54.32,65.88,36.48,57.68,64.26
