In [62]:
import sys
sys.path.append('../')
import torch, math
import torch.nn as nn
from torch.utils.data import DataLoader
from dataset.dataset import TokenDataset, DataCollator
from model.transformer import TransformerModel
from model.config import ModelConfig

In [69]:
def process_batch(batch, model, mse_criterion, ce_criterion, device):
    inputs = {
        'tokens': batch.tokens.to(device),
        'token_types': batch.token_types.to(device),
        'scope_depth': batch.scope_depth.float().to(device),
        'attention_mask': batch.attention_mask.float().to(device)
    }

    spaces, newlines = batch.spaces.float().to(device), batch.newlines.long().to(device)
    attn_mask = inputs['attention_mask']
    max_len = attn_mask.size(1)
    spaces, newlines, attn_mask = spaces[:, :max_len], newlines[:, :max_len], attn_mask[:, :max_len]

    space_output, newline_logits = model(**inputs)

    batch_preds = {
        'space_preds': space_output,
        'newline_preds': newline_logits.argmax(dim=-1),
        'true_spaces': spaces,
        'true_newlines': newlines,
        'attention_mask': attn_mask
    }

    raw_space_loss = mse_criterion(space_output, spaces) * attn_mask
    space_loss = raw_space_loss.sum() / attn_mask.sum()
    
    reshaped_logits = newline_logits.view(-1, model.config.max_newlines + 1)
    newline_raw_loss = ce_criterion(reshaped_logits, newlines.view(-1)).view_as(spaces)
    newline_loss = (newline_raw_loss * attn_mask).sum() / attn_mask.sum()
    loss = (space_loss + newline_loss) / 2

    space_mse = ((space_output - spaces) ** 2 * attn_mask).sum().item()
    newline_correct = ((newline_logits.argmax(dim=-1) == newlines).float() * attn_mask).sum().item()
    num_tokens = attn_mask.sum().item()

    rounded_space_preds = torch.round(space_output)
    space_correct = ((rounded_space_preds == spaces).float() * attn_mask).sum().item()
    
    critical_errors = ((space_output < 0.1) & (spaces >= 1) & (attn_mask == 1)).sum().item()
    
    return batch_preds, loss.item(), space_mse, newline_correct, num_tokens, critical_errors, space_correct

def evaluate(model, data_loader, device='cuda'):
    model.eval()
    total_loss = total_critical_errors = total_space_mse = total_newline_correct = total_tokens = total_space_correct = 0
    mse_criterion = nn.MSELoss(reduction='none')
    ce_criterion = nn.CrossEntropyLoss(reduction='none')
    
    predictions = []
    
    with torch.no_grad():
        for batch in data_loader:
            batch_inputs = (batch, model, mse_criterion, ce_criterion, device)
            batch_preds, loss, space_mse, newline_correct, num_tokens, critical_errors, space_correct = process_batch(*batch_inputs)
            
            predictions.append(batch_preds)
            total_loss += loss
            total_space_mse += space_mse
            total_newline_correct += newline_correct
            total_tokens += num_tokens
            total_critical_errors += critical_errors
            total_space_correct += space_correct
    
    metrics = {
        'avg_loss': total_loss / len(data_loader),
        'space_mse': total_space_mse / total_tokens,
        'space_accuracy': total_space_correct / total_tokens * 100,
        'newline_accuracy': total_newline_correct / total_tokens * 100,
        'critical_error_rate': total_critical_errors / total_tokens * 100
    }
    
    return metrics, predictions

In [70]:
def load_model(checkpoint_path, device):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    print("\n\n")
    config_dict = checkpoint['model_config']
    token_to_idx = checkpoint['token_to_idx']
    type_to_idx = checkpoint['type_to_idx']

    model_config = ModelConfig(
        vocab_size=len(token_to_idx),
        type_vocab_size=len(type_to_idx),
        max_newlines=config_dict['max_newlines'],
        d_model=config_dict.get('d_model', 256),
        nhead=config_dict.get('nhead', 8),
        num_encoder_layers=config_dict.get('num_encoder_layers', 6),
        dim_feedforward=config_dict.get('dim_feedforward', 1024),
        dropout=config_dict.get('dropout', 0.1),
        max_seq_length=config_dict.get('max_seq_length', 2048)
    )

    model = TransformerModel(model_config).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])

    return model, token_to_idx, type_to_idx

def load_test_data(test_data_path, token_to_idx, type_to_idx, batch_size):
    test_dataset = TokenDataset(test_data_path, token_to_idx, type_to_idx)
    test_loader_args = { 'batch_size': batch_size, 'shuffle': False, 'collate_fn': DataCollator() }
    test_loader = DataLoader(test_dataset, **test_loader_args)
    return test_loader

In [71]:
TEST_DATA_PATH = "./../dataset/data/test.jsonl"
MODEL_CHECKPOINT_PATH = "../checkpoints/model_checkpoint.pt"
BATCH_SIZE = 32

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model, token_to_idx, type_to_idx = load_model(MODEL_CHECKPOINT_PATH, device)
test_loader = load_test_data(TEST_DATA_PATH, token_to_idx, type_to_idx, BATCH_SIZE)
metrics, predictions = evaluate(model, test_loader, device)

print(f"\nAverage loss: {metrics['avg_loss']:.4f}")
print(f"Spacing prediction root mean squared error: {math.sqrt(metrics['space_mse']):.4f}")
print(f"Spacing prediction accuracy: {metrics['space_accuracy']:.2f}%")
print(f"Newline prediction accuracy: {metrics['newline_accuracy']:.2f}%")
print(f"Critical error rate: {metrics['critical_error_rate']:.3f}%")

print("\nNote: A critical error is defined as a wrong prediction which potentially breaks the Java code.")

  checkpoint = torch.load(checkpoint_path, map_location=device)





Unknown token ratio: 14.56%

Average loss: 0.9781
Spacing prediction root mean squared error: 1.2829
Spacing prediction accuracy: 77.77%
Newline prediction accuracy: 88.89%
Critical error rate: 0.028%

Note: A critical error is defined as a wrong prediction which potentially breaks the Java code.
