<a href="https://colab.research.google.com/github/marcovzla/discobert/blob/master/discobert.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#!pip install ipython tqdm requests boto3 regex click joblib nltk scikit-learn jupyter
!pip install torch torchvision transformers
!git clone https://github.com/marcovzla/discobert.git



fatal: destination path 'discobert' already exists and is not an empty directory.


## Data is mounted in GDrive

In [0]:
# from google.colab import drive
# drive.mount('/content/drive')

## Running the model... from the github dir

In [1]:
cd discobert


/content/discobert


Pull the repo for latest code

In [2]:
!git pull origin master
# ls

From https://github.com/marcovzla/discobert
 * branch            master     -> FETCH_HEAD
Already up to date.


# Training code

In [0]:
import os
import torch
from torch.optim import Adam
from tqdm.autonotebook import tqdm
from model import DiscoBertModel
from rst import load_annotations, iter_spans_only

DATA="../drive/My Drive/discobert/data/"
train_dir = os.path.join(DATA, 'training_subset')
# train_dir = os.path.join(DATA, 'validation')
val_dir = os.path.join(DATA, 'validation')
model_dir = "../drive/My Drive/discobert/models/colab"
lr = 1e-3
num_epochs = 10
device = 'cuda'

def train(num_epochs, learning_rate, device, train_dir, val_dir, model_dir):
    torch.cuda.empty_cache()
    discobert = DiscoBertModel()
    discobert.set_device(device, init_weights=True)
    discobert.to(device)

    # setup the optimizer, loss, etc
    optimizer = Adam(params=discobert.parameters(), lr=learning_rate)

    # for each epoch
    for epoch_i in range(num_epochs):
        with open(os.path.join(model_dir, "log.txt"), 'a') as logfile:
            print(f'Beginning epoch {epoch_i}')
            print(f'Beginning epoch {epoch_i}', file=logfile)
        for annotation in tqdm(list(load_annotations(train_dir))):
            discobert.zero_grad()
            loss, pred_tree = discobert(annotation.edus, annotation.dis)
            loss.backward()
            optimizer.step()
        with open(os.path.join(model_dir, "log.txt"), 'a') as logfile:
            print(f'Finished epoch {epoch_i}')
            print(f'Finished epoch {epoch_i}', file=logfile)

        # save model
        epoch_model_dir = os.path.join(model_dir, f'discobert_{epoch_i}')
        if not os.path.exists(epoch_model_dir):
            os.makedirs(epoch_model_dir)
        discobert.save_pretrained(epoch_model_dir)
        # evaluate on validation
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        #gc.collect()
        predict(val_dir, discobert)

def predict(data_dir, discobert):
    # discobert = DiscoBertModel.from_pretrained(model_dir)
    # discobert.set_device(device, init_weights=False)
    # discobert.to(device)

    all_gold_nodes = []
    all_pred_nodes = []
    for annotation in tqdm(list(load_annotations(data_dir))):
        pred_tree = discobert(annotation.edus)[0]
        all_gold_nodes.extend(annotation.dis.get_nonterminals())
        all_pred_nodes.extend(pred_tree.get_nonterminals())

        all_gold_spans = [f'{annotation.docid}_{x}' for x in list(iter_spans_only(all_gold_nodes))]
        all_pred_spans = [f'{annotation.docid}_{x}' for x in list(iter_spans_only(all_pred_nodes))]

    # print(all_gold_spans[3])
    # print(all_pred_spans[3])
    p, r, f1 = eval(all_gold_spans, all_pred_spans) # TODO confirm
    with open(os.path.join(model_dir, "log.txt"), 'a') as logfile:
      print(f'P:{p}\tR:{r}\tF1:{f1}')
      print(f'P:{p}\tR:{r}\tF1:{f1}', file=logfile)
    

def eval(gold, pred):
    TP, FP, FN = 0, 0, 0
    for g in gold:
        if g in pred:
            TP += 1
        else:
            FN += 1

    for p in pred:
        if p not in gold:
            FP += 1

    precision = TP / (TP + FP)
    recall = TP / (TP + FN)
    f1 = 2 * ((precision * recall) / (precision + recall))
    return precision, recall, f1

# train(num_epochs, lr, device, train_dir, val_dir, model_dir)

# Actually run the training

In [6]:
train(num_epochs, lr, device, train_dir, val_dir, model_dir)

Beginning epoch 0


HBox(children=(IntProgress(value=0, max=292), HTML(value='')))


Finished epoch 0


HBox(children=(IntProgress(value=0, max=50), HTML(value='')))


P:0.34905347060777153	R:0.4258508914100486	F1:0.38364665084869504
Beginning epoch 1


HBox(children=(IntProgress(value=0, max=292), HTML(value='')))


Finished epoch 1


HBox(children=(IntProgress(value=0, max=50), HTML(value='')))


P:0.22649140546006066	R:0.09076175040518639	F1:0.12959213190627714
Beginning epoch 2


HBox(children=(IntProgress(value=0, max=292), HTML(value='')))


Finished epoch 2


HBox(children=(IntProgress(value=0, max=50), HTML(value='')))


P:0.34905347060777153	R:0.4258508914100486	F1:0.38364665084869504
Beginning epoch 3


HBox(children=(IntProgress(value=0, max=292), HTML(value='')))


Finished epoch 3


HBox(children=(IntProgress(value=0, max=50), HTML(value='')))


P:0.34905347060777153	R:0.4258508914100486	F1:0.38364665084869504
Beginning epoch 4


HBox(children=(IntProgress(value=0, max=292), HTML(value='')))


Finished epoch 4


HBox(children=(IntProgress(value=0, max=50), HTML(value='')))


P:0.22649140546006066	R:0.09076175040518639	F1:0.12959213190627714
Beginning epoch 5


HBox(children=(IntProgress(value=0, max=292), HTML(value='')))


Finished epoch 5


HBox(children=(IntProgress(value=0, max=50), HTML(value='')))


P:0.34905347060777153	R:0.4258508914100486	F1:0.38364665084869504
Beginning epoch 6


HBox(children=(IntProgress(value=0, max=292), HTML(value='')))


Finished epoch 6


HBox(children=(IntProgress(value=0, max=50), HTML(value='')))


P:0.34905347060777153	R:0.4258508914100486	F1:0.38364665084869504
Beginning epoch 7


HBox(children=(IntProgress(value=0, max=292), HTML(value='')))


Finished epoch 7


HBox(children=(IntProgress(value=0, max=50), HTML(value='')))


P:0.22649140546006066	R:0.09076175040518639	F1:0.12959213190627714
Beginning epoch 8


HBox(children=(IntProgress(value=0, max=292), HTML(value='')))


Finished epoch 8


HBox(children=(IntProgress(value=0, max=50), HTML(value='')))


P:0.34905347060777153	R:0.4258508914100486	F1:0.38364665084869504
Beginning epoch 9


HBox(children=(IntProgress(value=0, max=292), HTML(value='')))


Finished epoch 9


HBox(children=(IntProgress(value=0, max=50), HTML(value='')))


P:0.34905347060777153	R:0.4258508914100486	F1:0.38364665084869504
