In [None]:
import torch
from torch.utils.data import DataLoader
import pickle
import numpy as np

In [None]:
# Dataset / model variant; interpolated into the pickle and checkpoint file
# names used by the following cells.
model_type = "conversation" # abc or conversation

# Prefer Apple's Metal backend (the original hard-coded choice), but fall back
# to CUDA or CPU so the notebook still runs on machines without MPS support.
# `getattr` guards against torch builds that do not expose `torch.backends.mps`.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
def _load_tensor_records(path):
    """Unpickle the records stored at ``path`` and convert them to tensors.

    Each record is indexed with the keys ``"data"`` and ``"label"``; the
    returned list holds one dict per record with ``"data"`` as a float tensor
    and ``"label"`` as a long tensor.
    """
    # NOTE: pickle.load can execute arbitrary code; only open trusted files.
    with open(path, "rb") as f:
        records = pickle.load(f)
    return [
        {
            "data": torch.tensor(record["data"]).float(),
            "label": torch.tensor(record["label"]).long(),
        }
        for record in records
    ]


# The raw records are no longer bound to a module-level `train` name, which
# would otherwise collide with the `train` function defined later on.
train_tensors = _load_tensor_records(f"../data/clean/{model_type}_train.pkl")
test_tensors = _load_tensor_records(f"../data/clean/{model_type}_test.pkl")

train_loader = DataLoader(train_tensors, batch_size=32, shuffle=True)
# Evaluation does not benefit from shuffling; keep the order deterministic.
test_loader = DataLoader(test_tensors, batch_size=32, shuffle=False)

In [None]:
# Mapping loaded from disk; its length sizes the classifier and its values
# label the confusion-matrix axes further down.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
id2label_path = f"../data/model/{model_type}_id2label.pkl"
with open(id2label_path, "rb") as id2label_file:
    id2label = pickle.load(id2label_file)

# Bare expression so the notebook displays the mapping.
id2label

In [None]:
from PointNet import PointNet

# One output class per entry in the id2label mapping.
num_classes = len(id2label)
print(f"Instantiating model with {num_classes} classes")
# PointNet is the project-local model imported above; presumably `device`
# tells it where to place its weights — TODO confirm against PointNet.py.
pointnet = PointNet(classes=num_classes, device=device)

In [6]:
from sklearn.metrics import accuracy_score

# Where the trained weights (state_dict) are saved after training and read
# back from for evaluation.
checkpoint_path = f"../data/model/{model_type}.pth"

def pointnetloss(outputs, labels, m3x3, m64x64, alpha = 0.0001):
    """NLL classification loss plus an orthogonality penalty on both transforms.

    Args:
        outputs: (batch, classes) scores passed to ``torch.nn.NLLLoss`` (which
            expects log-probabilities).
        labels: (batch,) integer class targets.
        m3x3: (batch, k, k) transform matrices returned by the model.
        m64x64: (batch, k', k') transform matrices returned by the model.
        alpha: weight of the orthogonality penalty.

    Returns:
        Scalar tensor: ``NLL + alpha * (||I - A A^T|| + ||I - B B^T||) / batch``,
        where each norm is taken over the whole batch of matrices.
    """
    criterion = torch.nn.NLLLoss()
    bs = outputs.size(0)

    def _orthogonality_gap(m):
        # Build the identity on the same device/dtype as the transform so the
        # loss does not depend on a notebook-level `device` global. The single
        # (k, k) identity broadcasts across the batch, and it is a constant,
        # so it does not need to track gradients.
        identity = torch.eye(m.size(1), device=m.device, dtype=m.dtype)
        return identity - torch.bmm(m, m.transpose(1, 2))

    penalty = torch.norm(_orthogonality_gap(m3x3)) + torch.norm(_orthogonality_gap(m64x64))
    return criterion(outputs, labels) + alpha * penalty / float(bs)

def train(model, train_loader, val_loader, epochs):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Relies on the notebook-level ``optimizer`` and ``device`` globals and on
    ``pointnetloss``; all three must be defined before this is called.

    Args:
        model: called as ``model(x)`` and must return ``(outputs, m3x3, m64x64)``.
        train_loader: iterable of batches indexed with ``'data'`` and ``'label'``.
        val_loader: iterable of batches in the same format, used for accuracy.
        epochs: number of passes over ``train_loader``.

    Prints the mean loss every 16 batches (plus once for any trailing batches)
    and the validation accuracy at the end of every epoch. Returns ``None``.
    """
    print_every_x = 16
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        batches_in_window = 0
        for i, data in enumerate(train_loader, 0):
            inputs = data['data'].to(device)
            labels = data['label'].to(device)
            optimizer.zero_grad()
            # The model is fed the batch with dimensions 1 and 2 swapped.
            outputs, m3x3, m64x64 = model(inputs.transpose(1,2))

            loss = pointnetloss(outputs, labels, m3x3, m64x64)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            batches_in_window += 1
            if batches_in_window == print_every_x:
                print(f"[Epoch: {epoch + 1}, Batch: {i + 1}, Loss: {running_loss / print_every_x:.4f}]")
                running_loss = 0.0
                batches_in_window = 0

        # Report the batches left over after the last full window; without
        # this, loaders shorter than `print_every_x` never print a loss.
        if batches_in_window:
            print(f"[Epoch: {epoch + 1}, Batch: {i + 1}, Loss: {running_loss / batches_in_window:.4f}]")

        model.eval()
        # validation
        total = correct = 0
        with torch.no_grad():
            for data in val_loader:
                inputs, labels = data['data'].to(device).float(), data['label'].to(device)
                outputs, _, _ = model(inputs.transpose(1,2))
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        if total:
            print(f'Accuracy: {100 * correct / total:.2f}%')
        else:
            # Avoid a ZeroDivisionError when the validation loader is empty.
            print('Accuracy: n/a (validation loader yielded no samples)')

# Hyper-parameters for this run.
learning_rate = 0.00025
num_epochs = 15

# `train` reads this module-level optimizer, so it must exist before the call.
optimizer = torch.optim.Adam(pointnet.parameters(), lr=learning_rate)

# The test loader doubles as the validation set here.
train(pointnet, train_loader, test_loader, epochs=num_epochs)

# Persist only the learned weights.
torch.save(pointnet.state_dict(), checkpoint_path)

In [None]:
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Evaluate on CPU with a fresh model restored from the saved checkpoint.
pointnet = PointNet(classes=num_classes, device="cpu")
pointnet.load_from_pth(checkpoint_path)
pointnet.eval()

all_preds = []
all_labels = []
with torch.no_grad():
    for data in test_loader:
        inputs, labels = data['data'].float(), data['label'].long()
        outputs, _, _ = pointnet(inputs.transpose(1,2))
        all_preds += outputs.argmax(dim=1).tolist()
        all_labels += labels.tolist()

# Plot confusion matrix
# Pin the label set so the matrix is always num_classes x num_classes, even if
# a class never appears in the test set or the predictions; otherwise the
# matrix shrinks and the tick labels no longer line up with the rows/columns.
class_ids = list(range(num_classes))
cm = confusion_matrix(all_labels, all_preds, labels=class_ids)

# Row-normalise; rows with no samples stay at 0 instead of becoming NaN.
row_totals = cm.sum(axis=1, keepdims=True)
cm = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float), where=row_totals != 0)

# Tick labels follow the dict's iteration order, as before — this assumes
# id2label's entries are ordered by class id (TODO confirm).
tick_labels = list(id2label.values())
ax = sns.heatmap(cm, annot=True, fmt='.2f', xticklabels=tick_labels, yticklabels=tick_labels)
ax.set_ylabel('Actual')
ax.set_xlabel('Predicted')
plt.show()