In [1]:
import os
from sentence_transformers import SentenceTransformer, losses, BiSentenceTransformer
from sentence_transformers.readers import STSDataReader, FEVERReader
from sentence_transformers.datasets import SentencesDataset
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from torch.utils.data import DataLoader
import psycopg2
import time
import numpy as np

## Training with Siamese Model

In [2]:
base_model = SentenceTransformer('bert-base-nli-stsb-mean-tokens')
train_batch_size = 16
num_epochs = 1
warmup_steps=100
model_save_path='./fever-model'

In [3]:
reader = FEVERReader()
train_examples = reader.get_examples('train',table='test.train_article_rerank')
train_data = SentencesDataset(train_examples, base_model)
train_dataloader = DataLoader(train_data, shuffle=True, batch_size=train_batch_size)
train_loss = losses.CosineSimilarityLoss(model=base_model)

dev_examples = reader.get_examples('dev',table='test.test_article_rerank')
dev_data = SentencesDataset(examples=dev_examples, model=base_model)
dev_dataloader = DataLoader(dev_data, shuffle=False, batch_size=train_batch_size)
evaluator = EmbeddingSimilarityEvaluator(dev_dataloader)
model.fit(train_objectives=[(train_dataloader, train_loss)],
          evaluator=evaluator,
          epochs=num_epochs,
          evaluation_steps=1000,
          warmup_steps=warmup_steps,
          output_path=model_save_path,
          fp16=True,
          fp16_opt_level='O1'
)

trying to connect to postgres...
connected to postgres
downloading data
trying to connect to postgres...
connected to postgres
downloading data


NameError: name 'model' is not defined

## Training with Modified Siamese (dense feedforward layer for query)

In [4]:
model = BiSentenceTransformer(base_model)
train_batch_size = 16
num_epochs = 1
warmup_steps=100
model_save_path='./modified-fever'

In [4]:
reader = FEVERReader()
train_examples = reader.get_examples('train',table='test.train_article_rerank')
train_data = SentencesDataset(train_examples, base_model)
train_dataloader = DataLoader(train_data, shuffle=True, batch_size=train_batch_size)
train_loss = losses.BiCosineSimilarityLoss(model=model)

dev_examples = reader.get_examples('dev',table='test.test_article_rerank')
dev_data = SentencesDataset(examples=dev_examples, model=base_model)
dev_dataloader = DataLoader(dev_data, shuffle=False, batch_size=train_batch_size)

trying to connect to postgres...
connected to postgres
downloading data
trying to connect to postgres...
connected to postgres
downloading data


In [None]:
model.fit((train_dataloader, train_loss),
          None,
          epochs=num_epochs,
          evaluation_steps=1000,
          warmup_steps=warmup_steps,
          output_path_base=model_save_path,
          fp16=True,
          fp16_opt_level='O1'
)

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration:   0%|          | 0/6823 [00:00<?, ?it/s][A

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic



Iteration:   0%|          | 1/6823 [00:00<1:19:04,  1.44it/s][A
Iteration:   0%|          | 2/6823 [00:00<1:02:57,  1.81it/s][A
Iteration:   0%|          | 3/6823 [00:01<51:20,  2.21it/s]  [A
Iteration:   0%|          | 4/6823 [00:01<43:40,  2.60it/s][A
Iteration:   0%|          | 5/6823 [00:01<37:25,  3.04it/s][A
Iteration:   0%|          | 6/6823 [00:01<33:20,  3.41it/s][A
Iteration:   0%|          | 7/6823 [00:01<30:22,  3.74it/s][A
Iteration:   0%|          | 8/6823 [00:02<28:03,  4.05it/s][A
Iteration:   0%|          | 9/6823 [00:02<26:33,  4.28it/s][A
Iteration:   0%|          | 10/6823 [00:02<25:10,  4.51it/s][A
Iteration:   0%|          | 11/6823 [00:02<24:41,  4.60it/s][A
Iteration:   0%|          | 12/6823 [00:02<24:13,  4.69it/s][A
Iteration:   0%|          | 13/6823 [00:03<23:41,  4.79it/s][A
Iteration:   0%|          | 14/6823 [00:03<23:14,  4.88it/s][A
Iteration:   0%|          | 15/6823 [00:03<23:00,  4.93it/s][A
Iteration:   0%|          | 16/6823 [00:03

Iteration:   2%|▏         | 121/6823 [00:24<21:30,  5.19it/s][A
Iteration:   2%|▏         | 122/6823 [00:24<21:22,  5.22it/s][A
Iteration:   2%|▏         | 123/6823 [00:24<21:22,  5.22it/s][A
Iteration:   2%|▏         | 124/6823 [00:25<21:29,  5.20it/s][A
Iteration:   2%|▏         | 125/6823 [00:25<22:01,  5.07it/s][A
Iteration:   2%|▏         | 126/6823 [00:25<22:24,  4.98it/s][A
Iteration:   2%|▏         | 127/6823 [00:25<21:38,  5.16it/s][A
Iteration:   2%|▏         | 128/6823 [00:25<22:00,  5.07it/s][A
Iteration:   2%|▏         | 129/6823 [00:26<22:00,  5.07it/s][A
Iteration:   2%|▏         | 130/6823 [00:26<21:51,  5.11it/s][A
Iteration:   2%|▏         | 131/6823 [00:26<21:41,  5.14it/s][A
Iteration:   2%|▏         | 132/6823 [00:26<22:37,  4.93it/s][A
Iteration:   2%|▏         | 133/6823 [00:26<23:15,  4.79it/s][A
Iteration:   2%|▏         | 134/6823 [00:27<22:32,  4.95it/s][A
Iteration:   2%|▏         | 135/6823 [00:27<22:24,  4.97it/s][A
Iteration:   2%|▏        

Iteration:   4%|▎         | 247/6823 [00:49<20:07,  5.44it/s][A
Iteration:   4%|▎         | 248/6823 [00:49<20:03,  5.46it/s][A
Iteration:   4%|▎         | 249/6823 [00:49<20:10,  5.43it/s][A
Iteration:   4%|▎         | 250/6823 [00:50<21:07,  5.19it/s][A
Iteration:   4%|▎         | 251/6823 [00:50<22:05,  4.96it/s][A
Iteration:   4%|▎         | 252/6823 [00:50<21:25,  5.11it/s][A
Iteration:   4%|▎         | 253/6823 [00:50<21:17,  5.14it/s][A
Iteration:   4%|▎         | 254/6823 [00:50<22:00,  4.97it/s][A
Iteration:   4%|▎         | 255/6823 [00:51<21:43,  5.04it/s][A
Iteration:   4%|▍         | 256/6823 [00:51<21:34,  5.07it/s][A
Iteration:   4%|▍         | 257/6823 [00:51<22:24,  4.88it/s][A
Iteration:   4%|▍         | 258/6823 [00:51<24:18,  4.50it/s][A
Iteration:   4%|▍         | 259/6823 [00:51<23:15,  4.70it/s][A
Iteration:   4%|▍         | 260/6823 [00:52<22:48,  4.79it/s][A
Iteration:   4%|▍         | 261/6823 [00:52<22:06,  4.95it/s][A
Iteration:   4%|▍        

Iteration:   5%|▌         | 373/6823 [01:14<20:53,  5.14it/s][A
Iteration:   5%|▌         | 374/6823 [01:14<20:39,  5.20it/s][A
Iteration:   5%|▌         | 375/6823 [01:15<21:33,  4.98it/s][A
Iteration:   6%|▌         | 376/6823 [01:15<21:48,  4.93it/s][A
Iteration:   6%|▌         | 377/6823 [01:15<21:29,  5.00it/s][A
Iteration:   6%|▌         | 378/6823 [01:15<21:24,  5.02it/s][A
Iteration:   6%|▌         | 379/6823 [01:15<20:55,  5.13it/s][A
Iteration:   6%|▌         | 380/6823 [01:16<20:37,  5.21it/s][A
Iteration:   6%|▌         | 381/6823 [01:16<21:02,  5.10it/s][A
Iteration:   6%|▌         | 382/6823 [01:16<21:13,  5.06it/s][A
Iteration:   6%|▌         | 383/6823 [01:16<21:14,  5.05it/s][A
Iteration:   6%|▌         | 384/6823 [01:16<20:46,  5.17it/s][A
Iteration:   6%|▌         | 385/6823 [01:17<20:49,  5.15it/s][A
Iteration:   6%|▌         | 386/6823 [01:17<21:09,  5.07it/s][A
Iteration:   6%|▌         | 387/6823 [01:17<20:48,  5.15it/s][A
Iteration:   6%|▌        

Iteration:   7%|▋         | 499/6823 [01:39<19:23,  5.44it/s][A
Iteration:   7%|▋         | 500/6823 [01:39<19:52,  5.30it/s][A
Iteration:   7%|▋         | 501/6823 [01:39<19:39,  5.36it/s][A
Iteration:   7%|▋         | 502/6823 [01:40<19:33,  5.39it/s][A
Iteration:   7%|▋         | 503/6823 [01:40<19:54,  5.29it/s][A
Iteration:   7%|▋         | 504/6823 [01:40<20:05,  5.24it/s][A
Iteration:   7%|▋         | 505/6823 [01:40<19:58,  5.27it/s][A
Iteration:   7%|▋         | 506/6823 [01:40<20:04,  5.25it/s][A
Iteration:   7%|▋         | 507/6823 [01:41<19:55,  5.28it/s][A
Iteration:   7%|▋         | 508/6823 [01:41<20:07,  5.23it/s][A
Iteration:   7%|▋         | 509/6823 [01:41<20:52,  5.04it/s][A
Iteration:   7%|▋         | 510/6823 [01:41<20:28,  5.14it/s][A
Iteration:   7%|▋         | 511/6823 [01:41<20:10,  5.21it/s][A
Iteration:   8%|▊         | 512/6823 [01:42<20:25,  5.15it/s][A
Iteration:   8%|▊         | 513/6823 [01:42<21:04,  4.99it/s][A
Iteration:   8%|▊        

Iteration:   9%|▉         | 625/6823 [02:04<20:07,  5.13it/s][A
Iteration:   9%|▉         | 626/6823 [02:04<19:34,  5.28it/s][A
Iteration:   9%|▉         | 627/6823 [02:05<19:03,  5.42it/s][A
Iteration:   9%|▉         | 628/6823 [02:05<19:05,  5.41it/s][A
Iteration:   9%|▉         | 629/6823 [02:05<18:55,  5.46it/s][A
Iteration:   9%|▉         | 630/6823 [02:05<19:26,  5.31it/s][A
Iteration:   9%|▉         | 631/6823 [02:05<19:37,  5.26it/s][A
Iteration:   9%|▉         | 632/6823 [02:06<19:39,  5.25it/s][A
Iteration:   9%|▉         | 633/6823 [02:06<19:31,  5.28it/s][A
Iteration:   9%|▉         | 634/6823 [02:06<19:13,  5.37it/s][A
Iteration:   9%|▉         | 635/6823 [02:06<20:17,  5.08it/s][A
Iteration:   9%|▉         | 636/6823 [02:06<19:52,  5.19it/s][A
Iteration:   9%|▉         | 637/6823 [02:07<19:34,  5.27it/s][A
Iteration:   9%|▉         | 638/6823 [02:07<20:19,  5.07it/s][A
Iteration:   9%|▉         | 639/6823 [02:07<20:07,  5.12it/s][A
Iteration:   9%|▉        

Iteration:  11%|█         | 751/6823 [02:29<22:11,  4.56it/s][A
Iteration:  11%|█         | 752/6823 [02:29<21:19,  4.74it/s][A
Iteration:  11%|█         | 753/6823 [02:30<21:29,  4.71it/s][A
Iteration:  11%|█         | 754/6823 [02:30<21:32,  4.69it/s][A
Iteration:  11%|█         | 755/6823 [02:30<21:01,  4.81it/s][A
Iteration:  11%|█         | 756/6823 [02:30<21:28,  4.71it/s][A
Iteration:  11%|█         | 757/6823 [02:30<21:32,  4.69it/s][A
Iteration:  11%|█         | 758/6823 [02:31<21:48,  4.63it/s][A
Iteration:  11%|█         | 759/6823 [02:31<21:22,  4.73it/s][A
Iteration:  11%|█         | 760/6823 [02:31<20:59,  4.82it/s][A
Iteration:  11%|█         | 761/6823 [02:31<21:09,  4.77it/s][A
Iteration:  11%|█         | 762/6823 [02:32<20:52,  4.84it/s][A
Iteration:  11%|█         | 763/6823 [02:32<21:19,  4.74it/s][A
Iteration:  11%|█         | 764/6823 [02:32<20:49,  4.85it/s][A
Iteration:  11%|█         | 765/6823 [02:32<20:04,  5.03it/s][A
Iteration:  11%|█        

Iteration:  13%|█▎        | 877/6823 [02:54<18:40,  5.31it/s][A
Iteration:  13%|█▎        | 878/6823 [02:54<19:07,  5.18it/s][A
Iteration:  13%|█▎        | 879/6823 [02:54<19:51,  4.99it/s][A
Iteration:  13%|█▎        | 880/6823 [02:55<19:43,  5.02it/s][A
Iteration:  13%|█▎        | 881/6823 [02:55<19:16,  5.14it/s][A
Iteration:  13%|█▎        | 882/6823 [02:55<19:18,  5.13it/s][A
Iteration:  13%|█▎        | 883/6823 [02:55<20:32,  4.82it/s][A
Iteration:  13%|█▎        | 884/6823 [02:55<20:16,  4.88it/s][A
Iteration:  13%|█▎        | 885/6823 [02:56<20:09,  4.91it/s][A
Iteration:  13%|█▎        | 886/6823 [02:56<19:56,  4.96it/s][A
Iteration:  13%|█▎        | 887/6823 [02:56<19:41,  5.02it/s][A
Iteration:  13%|█▎        | 888/6823 [02:56<19:18,  5.12it/s][A
Iteration:  13%|█▎        | 889/6823 [02:56<19:29,  5.07it/s][A
Iteration:  13%|█▎        | 890/6823 [02:57<20:31,  4.82it/s][A
Iteration:  13%|█▎        | 891/6823 [02:57<19:45,  5.00it/s][A
Iteration:  13%|█▎       

Iteration:  15%|█▍        | 1003/6823 [03:19<19:28,  4.98it/s][A
Iteration:  15%|█▍        | 1004/6823 [03:19<19:26,  4.99it/s][A
Iteration:  15%|█▍        | 1005/6823 [03:19<19:35,  4.95it/s][A
Iteration:  15%|█▍        | 1006/6823 [03:19<19:00,  5.10it/s][A
Iteration:  15%|█▍        | 1007/6823 [03:19<18:53,  5.13it/s][A
Iteration:  15%|█▍        | 1008/6823 [03:20<19:30,  4.97it/s][A
Iteration:  15%|█▍        | 1009/6823 [03:20<19:23,  5.00it/s][A
Iteration:  15%|█▍        | 1010/6823 [03:20<19:43,  4.91it/s][A
Iteration:  15%|█▍        | 1011/6823 [03:20<19:20,  5.01it/s][A
Iteration:  15%|█▍        | 1012/6823 [03:20<18:52,  5.13it/s][A
Iteration:  15%|█▍        | 1013/6823 [03:21<18:55,  5.12it/s][A
Iteration:  15%|█▍        | 1014/6823 [03:21<18:47,  5.15it/s][A
Iteration:  15%|█▍        | 1015/6823 [03:21<18:51,  5.13it/s][A
Iteration:  15%|█▍        | 1016/6823 [03:21<18:58,  5.10it/s][A
Iteration:  15%|█▍        | 1017/6823 [03:21<19:03,  5.08it/s][A
Iteration:

Iteration:  17%|█▋        | 1127/6823 [03:43<19:36,  4.84it/s][A
Iteration:  17%|█▋        | 1128/6823 [03:43<19:38,  4.83it/s][A
Iteration:  17%|█▋        | 1129/6823 [03:43<19:11,  4.94it/s][A
Iteration:  17%|█▋        | 1130/6823 [03:44<19:53,  4.77it/s][A
Iteration:  17%|█▋        | 1131/6823 [03:44<19:14,  4.93it/s][A
Iteration:  17%|█▋        | 1132/6823 [03:44<18:59,  4.99it/s][A
Iteration:  17%|█▋        | 1133/6823 [03:44<18:52,  5.02it/s][A
Iteration:  17%|█▋        | 1134/6823 [03:44<18:49,  5.04it/s][A
Iteration:  17%|█▋        | 1135/6823 [03:45<18:43,  5.06it/s][A
Iteration:  17%|█▋        | 1136/6823 [03:45<19:49,  4.78it/s][A
Iteration:  17%|█▋        | 1137/6823 [03:45<19:08,  4.95it/s][A
Iteration:  17%|█▋        | 1138/6823 [03:45<19:23,  4.89it/s][A
Iteration:  17%|█▋        | 1139/6823 [03:45<18:55,  5.00it/s][A
Iteration:  17%|█▋        | 1140/6823 [03:46<19:04,  4.96it/s][A
Iteration:  17%|█▋        | 1141/6823 [03:46<18:58,  4.99it/s][A
Iteration:

Iteration:  18%|█▊        | 1251/6823 [04:08<18:55,  4.91it/s][A
Iteration:  18%|█▊        | 1252/6823 [04:08<18:39,  4.98it/s][A
Iteration:  18%|█▊        | 1253/6823 [04:08<17:58,  5.16it/s][A
Iteration:  18%|█▊        | 1254/6823 [04:08<18:28,  5.02it/s][A
Iteration:  18%|█▊        | 1255/6823 [04:09<18:20,  5.06it/s][A
Iteration:  18%|█▊        | 1256/6823 [04:09<18:09,  5.11it/s][A
Iteration:  18%|█▊        | 1257/6823 [04:09<17:56,  5.17it/s][A
Iteration:  18%|█▊        | 1258/6823 [04:09<17:56,  5.17it/s][A
Iteration:  18%|█▊        | 1259/6823 [04:09<18:24,  5.04it/s][A
Iteration:  18%|█▊        | 1260/6823 [04:10<18:12,  5.09it/s][A
Iteration:  18%|█▊        | 1261/6823 [04:10<17:54,  5.18it/s][A
Iteration:  18%|█▊        | 1262/6823 [04:10<17:26,  5.31it/s][A
Iteration:  19%|█▊        | 1263/6823 [04:10<17:52,  5.18it/s][A
Iteration:  19%|█▊        | 1264/6823 [04:10<17:36,  5.26it/s][A
Iteration:  19%|█▊        | 1265/6823 [04:11<17:43,  5.23it/s][A
Iteration:

Iteration:  20%|██        | 1375/6823 [04:32<18:00,  5.04it/s][A
Iteration:  20%|██        | 1376/6823 [04:33<17:49,  5.09it/s][A
Iteration:  20%|██        | 1377/6823 [04:33<18:56,  4.79it/s][A
Iteration:  20%|██        | 1378/6823 [04:33<19:45,  4.59it/s][A
Iteration:  20%|██        | 1379/6823 [04:33<18:51,  4.81it/s][A
Iteration:  20%|██        | 1380/6823 [04:33<18:17,  4.96it/s][A
Iteration:  20%|██        | 1381/6823 [04:34<18:10,  4.99it/s][A
Iteration:  20%|██        | 1382/6823 [04:34<17:48,  5.09it/s][A
Iteration:  20%|██        | 1383/6823 [04:34<17:44,  5.11it/s][A
Iteration:  20%|██        | 1384/6823 [04:34<17:44,  5.11it/s][A
Iteration:  20%|██        | 1385/6823 [04:34<17:57,  5.05it/s][A
Iteration:  20%|██        | 1386/6823 [04:35<17:56,  5.05it/s][A
Iteration:  20%|██        | 1387/6823 [04:35<17:19,  5.23it/s][A
Iteration:  20%|██        | 1388/6823 [04:35<17:34,  5.16it/s][A
Iteration:  20%|██        | 1389/6823 [04:35<18:21,  4.93it/s][A
Iteration:

Iteration:  22%|██▏       | 1499/6823 [04:57<17:31,  5.07it/s][A
Iteration:  22%|██▏       | 1500/6823 [04:57<17:38,  5.03it/s][A
Iteration:  22%|██▏       | 1501/6823 [04:58<17:34,  5.05it/s][A
Iteration:  22%|██▏       | 1502/6823 [04:58<17:28,  5.08it/s][A
Iteration:  22%|██▏       | 1503/6823 [04:58<17:19,  5.12it/s][A
Iteration:  22%|██▏       | 1504/6823 [04:58<17:26,  5.08it/s][A
Iteration:  22%|██▏       | 1505/6823 [04:58<17:23,  5.10it/s][A
Iteration:  22%|██▏       | 1506/6823 [04:59<17:21,  5.10it/s][A
Iteration:  22%|██▏       | 1507/6823 [04:59<17:56,  4.94it/s][A
Iteration:  22%|██▏       | 1508/6823 [04:59<17:24,  5.09it/s][A
Iteration:  22%|██▏       | 1509/6823 [04:59<18:22,  4.82it/s][A
Iteration:  22%|██▏       | 1510/6823 [04:59<18:01,  4.91it/s][A
Iteration:  22%|██▏       | 1511/6823 [05:00<17:26,  5.07it/s][A
Iteration:  22%|██▏       | 1512/6823 [05:00<17:09,  5.16it/s][A
Iteration:  22%|██▏       | 1513/6823 [05:00<17:48,  4.97it/s][A
Iteration:

Iteration:  24%|██▍       | 1623/6823 [05:22<17:35,  4.93it/s][A
Iteration:  24%|██▍       | 1624/6823 [05:22<17:12,  5.03it/s][A
Iteration:  24%|██▍       | 1625/6823 [05:22<16:58,  5.11it/s][A
Iteration:  24%|██▍       | 1626/6823 [05:22<17:15,  5.02it/s][A
Iteration:  24%|██▍       | 1627/6823 [05:23<17:06,  5.06it/s][A
Iteration:  24%|██▍       | 1628/6823 [05:23<17:16,  5.01it/s][A
Iteration:  24%|██▍       | 1629/6823 [05:23<17:46,  4.87it/s][A
Iteration:  24%|██▍       | 1630/6823 [05:23<17:29,  4.95it/s][A
Iteration:  24%|██▍       | 1631/6823 [05:23<17:17,  5.00it/s][A
Iteration:  24%|██▍       | 1632/6823 [05:24<17:35,  4.92it/s][A
Iteration:  24%|██▍       | 1633/6823 [05:24<16:53,  5.12it/s][A
Iteration:  24%|██▍       | 1634/6823 [05:24<17:04,  5.06it/s][A
Iteration:  24%|██▍       | 1635/6823 [05:24<17:00,  5.08it/s][A
Iteration:  24%|██▍       | 1636/6823 [05:24<16:54,  5.11it/s][A
Iteration:  24%|██▍       | 1637/6823 [05:25<16:48,  5.14it/s][A
Iteration:

Iteration:  26%|██▌       | 1747/6823 [05:46<17:08,  4.94it/s][A
Iteration:  26%|██▌       | 1748/6823 [05:47<17:30,  4.83it/s][A
Iteration:  26%|██▌       | 1749/6823 [05:47<16:57,  4.99it/s][A
Iteration:  26%|██▌       | 1750/6823 [05:47<16:38,  5.08it/s][A
Iteration:  26%|██▌       | 1751/6823 [05:47<17:05,  4.94it/s][A
Iteration:  26%|██▌       | 1752/6823 [05:47<16:29,  5.12it/s][A
Iteration:  26%|██▌       | 1753/6823 [05:48<16:35,  5.09it/s][A
Iteration:  26%|██▌       | 1754/6823 [05:48<16:46,  5.04it/s][A
Iteration:  26%|██▌       | 1755/6823 [05:48<16:33,  5.10it/s][A
Iteration:  26%|██▌       | 1756/6823 [05:48<15:58,  5.28it/s][A
Iteration:  26%|██▌       | 1757/6823 [05:48<16:34,  5.10it/s][A
Iteration:  26%|██▌       | 1758/6823 [05:49<16:24,  5.15it/s][A
Iteration:  26%|██▌       | 1759/6823 [05:49<17:00,  4.96it/s][A
Iteration:  26%|██▌       | 1760/6823 [05:49<16:31,  5.11it/s][A
Iteration:  26%|██▌       | 1761/6823 [05:49<16:39,  5.07it/s][A
Iteration:

In [13]:
import torch
save_path = f'{model_save_path}/Query'
os.makedirs(save_path, exist_ok=True)
torch.save(model.model_b.linear.state_dict(), os.path.join(save_path, 'pytorch_model.bin'))

In [14]:
load_path = os.path.join(model_save_path, 'Query', 'pytorch_model.bin')
model.model_b.linear.load_state_dict(torch.load(load_path))

<All keys matched successfully>

In [4]:
HOST = '54.196.150.193'
USER = 'postgres'
PASS = os.environ.get('PGPASS')
PGSSLROOTCERT = os.environ.get('PGSSLROOTCERT')
if PASS == None or PGSSLROOTCERT == None:
    print("Please set PG_PASS and PGSSLROOTCERT env variable")
    raise SystemExit()
DBNAME = 'fever'
POSTGRES_DSN = f'''dbname='fever' user='{USER}' host='{HOST}' password='{PASS}' '''
model = SentenceTransformer('./fever-model/')

In [24]:
query = '''
select a.id, l.text 
from wiki.articles a
join wiki.lines l on l.article_id = a.id and line_number = 0
'''

In [25]:
conn = psycopg2.connect(POSTGRES_DSN)
cur = conn.cursor()
cur.execute(query)
res = cur.fetchall()

In [None]:
sent_buffer = []
ids_buffer = []
BATCH_SIZE = 100000
if not os.path.exists('./fever-embs/'):
    os.makedirs('./fever-embs/')
start = time.time()
for i, batch in enumerate(res):
    sent_buffer.append(batch[1])
    ids_buffer.append(batch[0])
    if (i+1) % BATCH_SIZE == 0:
        embs = model.encode(sent_buffer, batch_size=32)
        ids_buffer = np.array(ids_buffer)
        ids_buffer = np.expand_dims(ids_buffer, 1)
        to_save = np.concatenate((ids_buffer, embs), 1)
        np.save(f'./fever-embs/emb-{i}',to_save)
        sent_buffer = []
        ids_buffer = []
        print(f'Running {i/(time.time() - start)} per second')

Running 339.2325537255288 per second


In [29]:
#embs = model.encode(sent_buffer, batch_size=32)
#ids_buffer = np.array(ids_buffer)
#ids_buffer = np.expand_dims(ids_buffer, 1)
#to_save = np.concatenate((ids_buffer, embs), 1)
np.save(f'./fever-embs/emb-last',to_save)