# DEXML - Dual-encoders for Extreme Multi-label Classification

## Download and install requirements (for colab)

In [None]:
!git clone https://github.com/nilesh2797/dexml

Cloning into 'dexml'...
remote: Enumerating objects: 58, done.[K
remote: Counting objects: 100% (58/58), done.[K
remote: Compressing objects: 100% (37/37), done.[K
remote: Total 58 (delta 16), reused 51 (delta 13), pack-reused 0[K
Receiving objects: 100% (58/58), 47.27 KiB | 820.00 KiB/s, done.
Resolving deltas: 100% (16/16), done.


In [None]:
%cd dexml

/content/dexml


In [15]:
%pip install -r requirements.txt

Note: you may need to restart the kernel to use updated packages.


## Imports

In [5]:
import sys, os, time, socket, yaml, wandb, logging, numpy as np
import logging.config
import scipy.sparse as sp
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
import matplotlib.pyplot as plt

In [6]:
from nets import TFEncoder
from datasets import DATA_MANAGERS
from utils.helper_utils import load_config_and_runtime_args, compute_xmc_metrics
from utils.nns_utils import ExactSearch

## Download and process datasets

In [4]:
import scipy.sparse as sp
def to_scipy_matrix(lol_inds, num_cols):
  cols = np.concatenate(lol_inds)
  rows = np.concatenate([[i]*len(x) for i, x in enumerate(lol_inds)])
  data = np.concatenate([[1.0]*len(x) for i, x in enumerate(lol_inds)])
  return sp.coo_matrix((data, (rows, cols)), (len(lol_inds), num_cols)).tocsr()

import pandas, gzip
def process_lf_amazon_datasets(dataset):
  print('Reading raw dataset files...')
  trn_df = pandas.read_json(gzip.open(f'Datasets/{dataset}/trn.json.gz'), lines=True)
  tst_df = pandas.read_json(gzip.open(f'Datasets/{dataset}/tst.json.gz'), lines=True)
  lbl_df = pandas.read_json(gzip.open(f'Datasets/{dataset}/lbl.json.gz'), lines=True)

  print('Processing Y (label) files...')
  trn_X_Y = to_scipy_matrix(trn_df.target_ind.values, lbl_df.shape[0])
  tst_X_Y = to_scipy_matrix(tst_df.target_ind.values, lbl_df.shape[0])

  sp.save_npz(f'Datasets/{dataset}/Y.trn.npz', trn_X_Y)
  sp.save_npz(f'Datasets/{dataset}/Y.tst.npz', tst_X_Y)

  print('Processing X (input) files...')
  print(*trn_df.title.apply(lambda x: x.strip()).values, sep='\n', file=open(f'Datasets/{dataset}/raw/trn_X.txt', 'w'))
  print(*tst_df.title.apply(lambda x: x.strip()).values, sep='\n', file=open(f'Datasets/{dataset}/raw/tst_X.txt', 'w'))
  print(*lbl_df.title.apply(lambda x: x.strip()).values, sep='\n', file=open(f'Datasets/{dataset}/raw/Y.txt', 'w'))

  print('Tokenizing X (input) files...')
  max_len = 32 if 'titles' in dataset else 128
  os.system(f"python utils/tokenization_utils.py --data-path Datasets/{dataset}/raw/trn_X.txt --tf-max-len {max_len}")
  os.system(f"python utils/tokenization_utils.py --data-path Datasets/{dataset}/raw/tst_X.txt --tf-max-len {max_len}")
  os.system(f"python utils/tokenization_utils.py --data-path Datasets/{dataset}/raw/Y.txt --tf-max-len {max_len}")

### Eurlex-4K

In [25]:
!mkdir -p Datasets
!cd Datasets; gdown 1A_sL_mzpkmnr6g0DSZ0_xJTr4GN-rIfi; tar -xvzf Eurlex-4K.tar.gz; mv Eurlex-4K EURLex-4K
!cd Datasets/EURLex-4K; mkdir -p raw; mv train_raw_texts.txt raw/trn_X.txt; mv test_raw_texts.txt raw/tst_X.txt; mv label_map.txt raw/Y.txt

Downloading...
From (original): https://drive.google.com/uc?id=1A_sL_mzpkmnr6g0DSZ0_xJTr4GN-rIfi
From (redirected): https://drive.google.com/uc?id=1A_sL_mzpkmnr6g0DSZ0_xJTr4GN-rIfi&confirm=t&uuid=d57c53de-f627-40ea-96b9-cacb4b9dcc1b
To: /Users/thekop/Repo/dissertation/examples/DEXML/Datasets/Eurlex-4K.tar.gz
100%|████████████████████████████████████████| 157M/157M [00:14<00:00, 11.1MB/s]
x ./Eurlex-4K/
x ./Eurlex-4K/train_raw_texts.txt
x ./Eurlex-4K/X.tst.npz
x ./Eurlex-4K/X.tst.finetune.xlnet.npy
x ./Eurlex-4K/X.trn.npz
x ./Eurlex-4K/Y.trn.npz
x ./Eurlex-4K/Y.tst.npz
x ./Eurlex-4K/test_raw_texts.txt
x ./Eurlex-4K/X.trn.finetune.xlnet.npy
x ./Eurlex-4K/label_map.txt


In [26]:
!python utils/tokenization_utils.py --data-path Datasets/EURLex-4K/raw/trn_X.txt --tf-max-len 128
!python utils/tokenization_utils.py --data-path Datasets/EURLex-4K/raw/tst_X.txt --tf-max-len 128
!python utils/tokenization_utils.py --data-path Datasets/EURLex-4K/raw/Y.txt --tf-max-len 128

Read 15449 lines
Dumping tokenized file at Datasets/EURLex-4K/raw/trn_X.bert-base-uncased_128.dat...
100%|█████████████████████████████████████████████| 1/1 [00:11<00:00, 11.78s/it]
Finished tokenize_dump_memmap in 11.8053 secs
Read 3865 lines
Dumping tokenized file at Datasets/EURLex-4K/raw/tst_X.bert-base-uncased_128.dat...
100%|█████████████████████████████████████████████| 1/1 [00:03<00:00,  3.82s/it]
Finished tokenize_dump_memmap in 3.8402 secs
Read 3956 lines
Dumping tokenized file at Datasets/EURLex-4K/raw/Y.bert-base-uncased_128.dat...
100%|█████████████████████████████████████████████| 1/1 [00:01<00:00,  1.32s/it]
Finished tokenize_dump_memmap in 1.3373 secs


### LF-AmazonTitles-131K

In [27]:
!mkdir -p Datasets
!cd Datasets; gdown 1WuquxCAg8D4lKr-eZXPv4nNw2S2lm7_E;
!cd Datasets; unzip LF-Amazon-131K.raw.zip; mv LF-Amazon-131K LF-AmazonTitles-131K; mkdir -p LF-AmazonTitles-131K/raw

Downloading...
From (original): https://drive.google.com/uc?id=1WuquxCAg8D4lKr-eZXPv4nNw2S2lm7_E
From (redirected): https://drive.google.com/uc?id=1WuquxCAg8D4lKr-eZXPv4nNw2S2lm7_E&confirm=t&uuid=9756dcb4-7a68-41a1-8c4b-5d70fae79a7c
To: /Users/thekop/Repo/dissertation/examples/DEXML/Datasets/LF-Amazon-131K.raw.zip
100%|████████████████████████████████████████| 245M/245M [00:30<00:00, 7.96MB/s]
Archive:  LF-Amazon-131K.raw.zip
   creating: LF-Amazon-131K/
  inflating: LF-Amazon-131K/lbl.json.gz  
  inflating: LF-Amazon-131K/trn.json.gz  
  inflating: LF-Amazon-131K/filter_labels_test.txt  
  inflating: LF-Amazon-131K/tst.json.gz  
  inflating: LF-Amazon-131K/filter_labels_train.txt  


In [28]:
process_lf_amazon_datasets('LF-AmazonTitles-131K')

Reading raw dataset files...
Processing Y (label) files...
Processing X (input) files...
Tokenizing X (input) files...
Read 294805 lines
Dumping tokenized file at Datasets/LF-AmazonTitles-131K/raw/trn_X.bert-base-uncased_128.dat...


100%|██████████| 1/1 [00:05<00:00,  5.84s/it]


Finished tokenize_dump_memmap in 6.0009 secs
Read 134835 lines
Dumping tokenized file at Datasets/LF-AmazonTitles-131K/raw/tst_X.bert-base-uncased_128.dat...


100%|██████████| 1/1 [00:02<00:00,  2.86s/it]


Finished tokenize_dump_memmap in 2.9192 secs
Read 131073 lines
Dumping tokenized file at Datasets/LF-AmazonTitles-131K/raw/Y.bert-base-uncased_128.dat...


100%|██████████| 1/1 [00:02<00:00,  2.84s/it]


Finished tokenize_dump_memmap in 2.8952 secs


### LF-AmazonTitles-1.3M

In [29]:
!mkdir -p Datasets
!cd Datasets; gdown 12zH4mL2RX8iSvH0VCNnd3QxO4DzuHWnK;
!cd Datasets; unzip LF-Amazon-1.3M.raw.zip; mv LF-Amazon-1.3M LF-AmazonTitles-1.3M; mkdir -p LF-AmazonTitles-1.3M/raw

Downloading...
From (original): https://drive.google.com/uc?id=12zH4mL2RX8iSvH0VCNnd3QxO4DzuHWnK
From (redirected): https://drive.google.com/uc?id=12zH4mL2RX8iSvH0VCNnd3QxO4DzuHWnK&confirm=t&uuid=b7ae47aa-59e5-4602-9ea8-276065c636e9
To: /Users/thekop/Repo/dissertation/examples/DEXML/Datasets/LF-Amazon-1.3M.raw.zip
100%|████████████████████████████████████████| 890M/890M [01:38<00:00, 9.08MB/s]
Archive:  LF-Amazon-1.3M.raw.zip
   creating: LF-Amazon-1.3M/
  inflating: LF-Amazon-1.3M/lbl.json.gz  
  inflating: LF-Amazon-1.3M/trn.json.gz  
  inflating: LF-Amazon-1.3M/filter_labels_test.txt  
  inflating: LF-Amazon-1.3M/tst.json.gz  
  inflating: LF-Amazon-1.3M/filter_labels_train.txt  


In [30]:
process_lf_amazon_datasets('LF-AmazonTitles-1.3M')

Reading raw dataset files...
Processing Y (label) files...
Processing X (input) files...
Tokenizing X (input) files...
Read 2248619 lines
Dumping tokenized file at Datasets/LF-AmazonTitles-1.3M/raw/trn_X.bert-base-uncased_128.dat...


100%|██████████| 5/5 [00:45<00:00,  9.11s/it]


Finished tokenize_dump_memmap in 46.0710 secs
Read 970237 lines
Dumping tokenized file at Datasets/LF-AmazonTitles-1.3M/raw/tst_X.bert-base-uncased_128.dat...


100%|██████████| 2/2 [00:18<00:00,  9.36s/it]


Finished tokenize_dump_memmap in 19.1293 secs
Read 1305265 lines
Dumping tokenized file at Datasets/LF-AmazonTitles-1.3M/raw/Y.bert-base-uncased_128.dat...


100%|██████████| 3/3 [00:25<00:00,  8.65s/it]


Finished tokenize_dump_memmap in 26.1658 secs


### LF-Wikipedia-500K

In [33]:
!mkdir -p Datasets/LF-Wikipedia-500K
!cd Datasets/LF-Wikipedia-500K; gdown 10RBSf6nC9C38wUMwqWur2Yd8mCwtup5K; gdown 1pEyKXtkwHhinuRxmARhtwEQ39VIughDf; gdown 1ZYTZPlnkPBCMcNqRRO-gNx8EPgtV-GL3; mkdir -p raw

Downloading...
From (original): https://drive.google.com/uc?id=10RBSf6nC9C38wUMwqWur2Yd8mCwtup5K
From (redirected): https://drive.google.com/uc?id=10RBSf6nC9C38wUMwqWur2Yd8mCwtup5K&confirm=t&uuid=50e16e32-fb16-45d7-881d-9f73075f6c63
To: /Users/thekop/Repo/dissertation/examples/DEXML/Datasets/LF-Wikipedia-500K/trn.raw.json.gz
100%|██████████████████████████████████████| 5.29G/5.29G [08:06<00:00, 10.9MB/s]
Downloading...
From (original): https://drive.google.com/uc?id=1pEyKXtkwHhinuRxmARhtwEQ39VIughDf
From (redirected): https://drive.google.com/uc?id=1pEyKXtkwHhinuRxmARhtwEQ39VIughDf&confirm=t&uuid=52e866b1-9338-4dc8-b598-3871eb9d7509
To: /Users/thekop/Repo/dissertation/examples/DEXML/Datasets/LF-Wikipedia-500K/tst.raw.json.gz
100%|██████████████████████████████████████| 2.30G/2.30G [03:29<00:00, 11.0MB/s]
Downloading...
From: https://drive.google.com/uc?id=1ZYTZPlnkPBCMcNqRRO-gNx8EPgtV-GL3
To: /Users/thekop/Repo/dissertation/examples/DEXML/Datasets/LF-Wikipedia-500K/Yf.txt
100%|████████

In [2]:
import pandas, gzip
import numpy as np
import re
def preprocess(s):
    return ' '.join([x for x in re.split(r"[^a-zA-Z0-9.]", s.lower()) if len(x) > 0])

In [7]:
dataset = 'LF-Wikipedia-500K'

print('Reading raw dataset files...')
trn_df = pandas.read_json(gzip.open(f'Datasets/{dataset}/trn.raw.json.gz'), lines=True)
tst_df = pandas.read_json(gzip.open(f'Datasets/{dataset}/tst.raw.json.gz'), lines=True)
Y = [preprocess(x.strip().split('->')[-1]) for x in open('Datasets/LF-Wikipedia-500K/Yf.txt')]

print('Processing Y (label) files...')
trn_X_Y = to_scipy_matrix(trn_df.target_ind.values, len(Y))
tst_X_Y = to_scipy_matrix(tst_df.target_ind.values, len(Y))
sp.save_npz(f'Datasets/{dataset}/Y.trn.npz', trn_X_Y)
sp.save_npz(f'Datasets/{dataset}/Y.tst.npz', tst_X_Y)

Reading raw dataset files...
Processing Y (label) files...


In [8]:
trn_df['text'] = trn_df.title + " " + trn_df.title + " " + trn_df.content
tst_df['text'] = tst_df.title + " " + tst_df.title + " " + tst_df.content
trn_df['text'] = trn_df['text'].apply(preprocess)
tst_df['text'] = tst_df['text'].apply(preprocess)

print('Processing X (input) files...')
print(*trn_df.text.values, sep='\n', file=open(f'Datasets/{dataset}/raw/trn_X.txt', 'w'))
print(*tst_df.text.values, sep='\n', file=open(f'Datasets/{dataset}/raw/tst_X.txt', 'w'))
print(*Y, sep='\n', file=open(f'Datasets/{dataset}/raw/Y.txt', 'w'))

In [117]:
print('Tokenizing X (input) files...')
!python utils/tokenization_utils.py --data-path Datasets/LF-Wikipedia-500K/raw/trn_X.txt --tf-max-len 128
!python utils/tokenization_utils.py --data-path Datasets/LF-Wikipedia-500K/raw/tst_X.txt --tf-max-len 128
!python utils/tokenization_utils.py --data-path Datasets/LF-Wikipedia-500K/raw/Y.txt --tf-max-len 128

Tokenizing X (input) files...


Read 1813391 lines
Dumping tokenized file at Datasets/LF-Wikipedia-500K/raw/trn_X.bert-base-uncased_128.dat...
100%|████████████████████████████████████████████| 4/4 [31:57<00:00, 479.45s/it]
Finished tokenize_dump_memmap in 1919.5861 secs
Read 783743 lines
Dumping tokenized file at Datasets/LF-Wikipedia-500K/raw/tst_X.bert-base-uncased_128.dat...
100%|████████████████████████████████████████████| 2/2 [14:59<00:00, 449.79s/it]
Finished tokenize_dump_memmap in 900.0780 secs
Read 501070 lines
Dumping tokenized file at Datasets/LF-Wikipedia-500K/raw/Y.bert-base-uncased_128.dat...
100%|█████████████████████████████████████████████| 2/2 [00:09<00:00,  5.00s/it]
Finished tokenize_dump_memmap in 10.4635 secs


## Run DEXML

In [118]:
dataset = 'LF-Wikipedia-500K'

In [121]:
args = load_config_and_runtime_args(['', f'configs/{dataset}/dist-de-all_decoupled-softmax.yaml'])
args.DATA_DIR = f'Datasets/{args.dataset}'
args.OUT_DIR = f'Results/{args.dataset}/demo'
os.makedirs(args.OUT_DIR, exist_ok=True)

In [122]:
data_manager = DATA_MANAGERS[args.data_manager](args)
trn_loader, val_loader, _ = data_manager.build_data_loaders()

neg_type: none


In [127]:
args.tf = f'quicktensor/dexml_{args.dataset.lower()}'
net = TFEncoder(args)

if torch.cuda.is_available():
  net.to('cuda:0');

In [128]:
# reduce batch-size if OOM
tst_embs = net.get_embs(val_loader.dataset.x_dataset, bsz=1024)
lbl_embs = net.get_embs(val_loader.dataset.y_dataset, bsz=1024)

Embedding: 100%|██████████████████████████████████████████████████████████████████████████████| 766/766 [01:47<00:00,  7.14it/s]
Embedding: 100%|██████████████████████████████████████████████████████████████████████████████| 490/490 [00:10<00:00, 46.65it/s]


In [129]:
es = ExactSearch(lbl_embs, K=100, device='cuda:0')
score_mat = es.search(tst_embs)

Searching in 1 shards


Searching shard 1/1: 100%|██████████████████████████████████████████████████████████████████| 1531/1531 [01:00<00:00, 25.46it/s]


Total time, time per point : 61.91s, 0.0790 ms/pt


In [130]:
from utils.helper_utils import load_filter_mat, _filter
if os.path.exists(f'Datasets/{args.dataset}/filter_labels_test.txt'):
    print('Filtering predictions...')
    filter_mat = load_filter_mat(f'Datasets/{args.dataset}/filter_labels_test.txt', val_loader.dataset.labels.shape)
    _filter(score_mat, filter_mat, copy=False);

In [131]:
compute_xmc_metrics(score_mat, val_loader.dataset.labels, inv_prop=None, disp=False)

Unnamed: 0,P@1,P@3,P@5,nDCG@1,nDCG@3,nDCG@5,MRR@10,R@10,R@50,R@100
Method,85.59,65.56,50.39,85.59,79.4,76.89,89.91,75.42,87.18,90.52
