# Mounting drive to save checkpoints there

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Transformers

In [2]:
!pip install transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/88/b1/41130a228dd656a1a31ba281598a968320283f48d42782845f6ba567f00b/transformers-4.2.2-py3-none-any.whl (1.8MB)
[K     |████████████████████████████████| 1.8MB 8.1MB/s 
Collecting tokenizers==0.9.4
[?25l  Downloading https://files.pythonhosted.org/packages/0f/1c/e789a8b12e28be5bc1ce2156cf87cb522b379be9cadc7ad8091a4cc107c4/tokenizers-0.9.4-cp36-cp36m-manylinux2010_x86_64.whl (2.9MB)
[K     |████████████████████████████████| 2.9MB 37.2MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
[K     |████████████████████████████████| 890kB 51.7MB/s 
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Created wheel for sacremoses: filename=sacremoses-0.0.43-cp36-none-any.whl size=893261 sha256=839591ec95cc1448c70

In [3]:
from transformers import BertForSequenceClassification
from transformers import BertTokenizer
import torch
from typing import List
import random
import sklearn
from math import ceil

In [4]:
class JudgeBERT(torch.nn.Module):
  """
  Adds a new head on top of the pre-trained BERT.
  """
  def __init__(self, freeze_base: bool, device:torch.device):
    """
    Ctor.
    :param freeze_base: If True the only the head layers will be trained.
    """
    torch.nn.Module.__init__(self)
    self.bert = BertForSequenceClassification.from_pretrained('bert-base-uncased').bert
    self.device = device
    self.head = torch.nn.Sequential(
        torch.nn.Linear(in_features=768, out_features=2),
        torch.nn.Softmax(dim=1)
    )
    self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    if freeze_base:
      for param in self.bert.parameters():
        param.requires_grad = False
  
  def forward(self, texts: List[str]) -> torch.Tensor:
    encoding = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
    y_hat = self.bert(encoding['input_ids'].to(self.device), encoding['attention_mask'].to(self.device))
    y_hat = self.head(y_hat[1])
    return y_hat

In [5]:
def train(model: torch.nn.Module, data: List[str], labels: List[str], batch_size: int, optimizer, verbose=False):
  """
  Trains the network for one epoch.
  """
  # Set to training mode.
  model.train(True)

  for start_i in range(0, len(data), batch_size):
    x = data[start_i:start_i+batch_size]
    y = labels[start_i:start_i+batch_size]
    # Convert the labels to torch tensor. Violation is 0, non-violation is 1.
    y = torch.tensor([0 if l == "violation" else 1 for l in y], dtype=torch.long, device=model.device)
    out = model(x)
    loss = torch.nn.functional.cross_entropy(input=out, target=y)
    if verbose:
      print(f"loss: {loss.item()}")
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()


def test(model: torch.nn.Module, data: List[str], labels: List[str], batch_size: int = 32) -> float:
  """
  Calculates classification accuracy on the dataset and returns the result.
  """
  model.train(False)
  accuracy = 0
  weight = 0
  for start_i in range(0, len(data), batch_size):
    x = data[start_i:start_i+batch_size]
    y = labels[start_i:start_i+batch_size]
    weight += len(x)
    with torch.no_grad():
      y = torch.tensor([0 if l == "violation" else 1 for l in y], dtype=torch.long)
      y_hat = torch.max(model(x), dim=1)[1].cpu()
    accuracy += len(x) * sklearn.metrics.accuracy_score(y, y_hat)
  return accuracy / weight

def save(model, optimizer, epoch):
  torch.save({
          'model_state_dict': model.state_dict(),
          'optimizer_state_dict': optimizer.state_dict(),
          'epoch': epoch
      }, f"/content/drive/My Drive/PR checkpoints/checkpoint_plus.pt")


# Loading the data

In [6]:
!unzip -qq crystal_ball_data.zip

In [7]:
from __future__ import print_function
import re, glob, sys, time, os, random
from time import gmtime, strftime
from sklearn.metrics import accuracy_score
from sklearn.metrics import recall_score, precision_score, f1_score, classification_report
from sklearn.metrics import confusion_matrix
from sklearn.svm import LinearSVC
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.pipeline import Pipeline, FeatureUnion
import pandas as pd
import warnings
from sklearn.model_selection import cross_val_predict, cross_val_score
from statistics import mean
from datetime import datetime
from time import time
import logging
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.pipeline import Pipeline, FeatureUnion
warnings.filterwarnings("ignore", category=UserWarning)
import pprint
from random import shuffle


pipeline = Pipeline([
	('tfidf', TfidfVectorizer(analyzer='word')),
	('clf', LinearSVC())
])


parameters = {
	'tfidf__ngram_range': [(1,2),(1,1),(1,3),(1,4),(2,2),(2,3),(2,4),(3,3),(3,4),(4,4)],
	#'tfidf__analyzer': ('word', 'char'),
	'tfidf__lowercase': (True, False),
	#'tfidf__max_df': (0.01, 1.0), # ignore words that occur as more than 1% of corpus
	'tfidf__min_df': (1, 2, 3), # we need to see a word at least (once, twice, thrice)
	'tfidf__use_idf': (False, True),
	#'tfidf__sublinear_tf': (False, True),
	'tfidf__binary': (False, True),
	'tfidf__norm': (None, 'l1', 'l2'),
	#'tfidf__max_features': (None, 2000, 5000),
	'tfidf__stop_words': (None, 'english'),

	#'tfidfchar_ngram_range': ((1,1),(1,2),(1,3),(1,4),(1,5),(1,6),(2,2),(2,3),(2,4),(2,5),(2,6),(3,3),(3,4),(3,5),(3,6),(4,4),(4,5),(4,6),(5,5),(5,6),(1,7),(2,7),(3,7),(4,7),(5,7),(6,7),(7,7)),
	
	
	'clf__C':(0.1, 1, 5)
}


def balance(Xtrain,Ytrain):
	#v = Ytrain.count('violation')
   # nv = Ytrain.count('non-violation')
	#print(v, nv)
	v = [i for i,val in enumerate(Ytrain) if val=='violation']
	nv = [i for i,val in enumerate(Ytrain) if val=='non-violation']
	if len(nv) < len(v):
		v = v[:len(nv)]
		Xtrain = [Xtrain[j] for j in v] + [Xtrain[i] for i in nv]
		Ytrain = [Ytrain[j] for j in v] + [Ytrain[i] for i in nv]
	if len(nv) > len(v):
		nv = nv[:len(v)]
		Xtrain = [Xtrain[j] for j in v] + [Xtrain[i] for i in nv]
		Ytrain = [Ytrain[j] for j in v] + [Ytrain[i] for i in nv]
	
	#print(Ytrain.count('violation'),Ytrain.count('non-violation'))
	#print('LEN', len(Xtrain), len(Ytrain))
	return Xtrain, Ytrain
	
	

def extract_text(starts, ends, cases, violation):
	facts = []
	D = []
	years = []
	for case in cases:
		contline = ''
		year = 0
		with open(case, 'r') as f:
			for line in f:
				#print(line)
				dat = re.search('^([0-9]{1,2}\s\w+\s([0-9]{4}))', line)
				if dat != None:
					year = int(dat.group(2))
					break
			if year>0:
				years.append(year)
				#print(year)
				wr = 0
				for line in f:
					if wr == 0:
						if re.search(starts, line) != None:
							wr = 1
					if wr == 1 and re.search(ends, line) == None:
						contline += line
						contline += '\n'
					elif re.search(ends, line) != None:
						break
				facts.append(contline)
	for i in range(len(facts)):
		D.append((facts[i], violation, years[i])) 
	return D

def extract_parts(article, violation, part, path):
  # Path is the path to the folder that contains all the text files.
	from os import listdir
	from os.path import isfile, join
	cases = [join(path, f) for f in listdir(path)]
	# cases = glob.glob(path)
	#print(cases)

		
	facts = []
	D = []
	years = []
	
	if part == 'relevant_law':
		for case in cases:
			year = 0
			contline = ''
			with open(case, 'r') as f:
				for line in f:
					dat = re.search('^([0-9]{1,2}\s\w+\s([0-9]{4}))', line)
					if dat != None:
						 #date = dat.group(1)
						year = int(dat.group(2))
						break
				if year> 0:
					years.append(year)
					wr = 0
					for line in f:
						if wr == 0:
							if re.search('RELEVANT', line) != None:
								wr = 1
						if wr == 1 and re.search('THE LAW', line) == None and re.search('PROCEEDINGS', line) == None:
							contline += line
							contline += '\n'
						elif re.search('THE LAW', line) != None or re.search('PROCEEDINGS', line) != None:
							break
					facts.append(contline)
		for i in range(len(facts)):
			D.append((facts[i], violation, years[i]))
		
	if part == 'facts':
		starts = 'THE FACTS'
		ends ='THE LAW'
		D = extract_text(starts, ends, cases, violation)
	if part == 'circumstances':
		starts = 'CIRCUMSTANCES'
		ends ='RELEVANT'
		D = extract_text(starts, ends, cases, violation)
	if part == 'procedure':
		starts = 'PROCEDURE'
		ends ='THE FACTS'
		D = extract_text(starts, ends, cases, violation)
	if part == 'procedure+facts':
		starts = 'PROCEDURE'
		ends ='THE LAW'
		D = extract_text(starts, ends, cases, violation)
	return D


def fetch(part, path, article):
  train_v = extract_parts(article, 'violation', part, path+'/train/'+article+'/violation/')
  train_nv = extract_parts(article, 'non-violation', part, path+'/train/'+article+'/non-violation/')
  test_v = extract_parts(article, 'violation', part, path+'/test20/'+article+'/violation/')
  test_nv = extract_parts(article, 'non-violation', part, path+'/test20/'+article+'/non-violation/')
 
  return train_v, train_nv, test_v, test_nv


def get_facts_dataset(prt, articles: List[int], shuffle: bool = False):
  """
  Returns a tuple of (training_data, training_labels, test_data, test labels)
  containing the data from the given articles. The 'data' fields are lists of strings
  that contain the FACTS part of the cases, while the 'labels' fields are also lists of
  strings containing either 'violation' or 'non-violation' for their respective data
  counterparts.
  :param articles: List of integers of article numbers.
  :param shuffle: Randomly shuffles the training set if True.
  """
  path = "/content/crystal_ball_data"
  traind = []
  trainl = []
  testd = []
  testl = []
  for i in articles:
    art = f"Article{i}"
    trv, trnv, tev, tenv = fetch(prt, path, art)
    traind.extend([e[0] for e in trv] + [e[0] for e in trnv])
    trainl.extend([e[1] for e in trv] + [e[1] for e in trnv])

    testd.extend([e[0] for e in tev] + [e[0] for e in tenv])
    testl.extend([e[1] for e in tev] + [e[1] for e in tenv])
  
  if shuffle:
    c = list(zip(traind, trainl))
    random.shuffle(c)

    traind, trainl = zip(*c)
  
  return traind, trainl, testd, testl

# Training

In [8]:
jb = JudgeBERT(freeze_base=False, device=torch.device('cuda:0')).to(torch.device("cuda:0"))

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=440473133.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




In [9]:
optimizer = torch.optim.Adam(params=jb.parameters(), lr=1e-5)

In [12]:
import time

t = time.time()
time.sleep(5)
duration = time.time() - t
print(f"{duration / 60 :.3f} min")

0.083 min


In [14]:
# Training and testing on the following articles:
article_numbers = [2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 18]
batch_size = 8
training_data, training_labels, test_data, test_labels = get_facts_dataset('procedure+facts', article_numbers, shuffle=True)

print(f"Size of the training data is: {len(training_data)}. Completing an epoch with batch size of {batch_size} will take {ceil(len(training_data) / batch_size)} iterations.")
t = time.time()
acc = test(jb, test_data, test_labels)
duration = time.time() - t
print(f"Duration: {duration / 60 :.3f} min")
print(acc)
for epoch in range(10):
  train(jb, training_data, training_labels, batch_size=batch_size, optimizer=optimizer, verbose=False)
  acc = test(jb, test_data, test_labels)
  save(model=jb, optimizer=optimizer, epoch=epoch+1)
  print(f"=== Epoch {epoch+1} completed. ===")
  print(f"| Test accuracy: {acc*100:.3f}%")

Size of the training data is: 3214. Completing an epoch with batch size of 8 will take 402 iterations.
Duration: 2.367 min
0.6977611940298507


KeyboardInterrupt: ignored

# Calcuating precision recall etc

In [None]:
def get_precision_recall_fscore_single(model: torch.nn.Module, data: List[str], labels: List[str], batch_size: int = 32):
  model.train(False)
  prec = 0
  recall = 0
  fscore = 0
  weight = 0
  for start_i in range(0, len(data), batch_size):
    x = data[start_i:start_i+batch_size]
    y = labels[start_i:start_i+batch_size]
    weight += len(x)
    with torch.no_grad():
      y = torch.tensor([0 if l == "violation" else 1 for l in y], dtype=torch.long)
      y_hat = torch.max(model(x), dim=1)[1].cpu()
    p, r, f, _ = sklearn.metrics.precision_recall_fscore_support(y, y_hat)
    prec += len(x) * p
    recall += len(x) * r
    fscore = len(x) * f
  return prec / weight, recall / weight, fscore / weight

def get_precision_recall_fscore(prt, model, article_numbers) -> dict:
  """
  Call this function to obtain per article network performance.
  It returns a dict with the article numbers as keys, and another dict
  as value, which contains 3 keys: {'precision', 'recall', 'f-score'}.
  """
  stats = dict()
  for i in article_numbers:
    _, _, test_data, test_labels = get_facts_dataset(prt, [i], shuffle=False)
    if i == 6:
      print(test_labels)
    p, r, f = get_precision_recall_fscore_single(model, test_data, test_labels)
    acc = test(jb, test_data, test_labels)
    print(f"acc {i}: {acc:.3f}")
    stats[i] = {"precision": p, "recall": r, "f-score": f}
  return stats

In [None]:
# Path to the saved checkpoint. Replace with yours.
# Alternatively comment the whole loadin thing out in case you just
# finished training, and the network is still in memory.

article_numbers = [2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 18]
training_data, training_labels, test_data, test_labels = get_facts_dataset('facts', article_numbers, shuffle=True)

path = f"/content/drive/My Drive/PR checkpoints/checkpoint_e10.pt"
ckpt = torch.load(path)
print("Checkpoint loaded.")

jb.load_state_dict(ckpt["model_state_dict"])
print("Network parameters loaded.")


# acc = test(jb, test_data, test_labels)
# print(f"| Test accuracy: {acc*100:.3f}%")
eval_stats = get_precision_recall_fscore('facts', jb, [2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 18])

acc 2: 0.929
acc 3: 0.754
acc 4: 1.000
acc 5: 0.737
['violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation', 'violation'

In [None]:
# Here is how to inspect the results. Prints the results of article 10.
# Notice that there are two values: the fist value is for violation,
# the second value is for non-violation.
print(eval_stats[2])

print("               PRECISION        RECALL    F-SCORE")
for k in eval_stats.keys():
  # print(eval_stats[k])
  print(f"Art {k}   non-violation: {eval_stats[k]['precision'][1]:.2f}  {eval_stats[k]['recall'][1]:.2f}  {eval_stats[k]['f-score'][1]:.2f}")
  print(f"Art {k}   violation: {eval_stats[k]['precision'][0]:.2f}  {eval_stats[k]['recall'][0]:.2f}  {eval_stats[k]['f-score'][0]:.2f}")
  print()

{'precision': array([1.   , 0.875]), 'recall': array([0.85714286, 1.        ]), 'f-score': array([0.92307692, 0.93333333])}
               PRECISION        RECALL    F-SCORE
Art 2   non-violation: 0.88  1.00  0.93
Art 2   violation: 1.00  0.86  0.92

Art 3   non-violation: 0.52  0.41  0.08
Art 3   violation: 0.56  0.48  0.00

Art 4   non-violation: 1.00  1.00  1.00
Art 4   violation: 1.00  1.00  1.00

Art 5   non-violation: 0.55  0.39  0.14
Art 5   violation: 0.56  0.68  0.00

Art 6   non-violation: 0.55  0.42  0.01
Art 6   violation: 0.54  0.47  0.00

Art 7   non-violation: 0.57  0.67  0.62
Art 7   violation: 0.60  0.50  0.55

Art 8   non-violation: 0.53  0.63  0.14
Art 8   violation: 0.54  0.37  0.00

Art 10   non-violation: 0.58  0.94  0.38
Art 10   violation: 0.59  0.33  0.00

Art 11   non-violation: 0.86  0.75  0.80
Art 11   violation: 0.78  0.88  0.82

Art 12   non-violation: 0.50  1.00  0.67
Art 12   violation: 0.00  0.00  0.00

Art 13   non-violation: 0.70  0.77  0.34
Art 13   