In [109]:
from datasets import load_dataset
import torch
from numpy.ma.core import shape

In [110]:
# Load the emotion-transition dialog dataset, using 16 worker processes and
# keeping the result in memory.
# NOTE(review): trust_remote_code=True allows code shipped with the dataset
# repository to run locally — confirm the source is trusted.
dataset = load_dataset(
	path="hermeschen1116/emotion_transition_from_dialog",
	trust_remote_code=True,
	keep_in_memory=True,
	num_proc=16,
)

In [111]:
# Shape of the first "user_emotion_compositions" entry of the first training example.
first_train_example = dataset["train"][0]
torch.tensor(first_train_example["user_emotion_compositions"][0]).shape

torch.Size([1, 7])

In [112]:
# Use CUDA when it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
	device = torch.device("cuda")
else:
	device = torch.device("cpu")

In [113]:
from src.models.libs import EmotionModel

# Build the emotion model with attention="dot_product", then move it to the
# device selected earlier in the notebook.
model = EmotionModel(attention="dot_product")
model = model.to(device)

In [114]:
from src.models.libs import representation_evolute


def evolve_bot_representations(samples):
	"""Compute the evolved bot emotion representations for one batch.

	Args:
		samples: Batched mapping from ``Dataset.map(batched=True)`` holding the
			"bot_initial_emotion_representation" and
			"user_emotion_compositions" columns.

	Returns:
		A dict with one "bot_emotion_representations" entry per example: the
		output of ``representation_evolute`` with its first element dropped.
	"""
	# Evaluation only: no gradients are needed, so skip building autograd graphs.
	with torch.no_grad():
		return {
			"bot_emotion_representations": [
				representation_evolute(
					model,
					# Only the first element of the initial representation is used as the seed.
					[torch.tensor(initial_representation[0]).to(device)],
					[torch.tensor(emotion).to(device) for emotion in user_compositions],
				)[1:]  # presumably index 0 is the initial representation itself — TODO confirm
				for initial_representation, user_compositions in zip(
					samples["bot_initial_emotion_representation"],
					samples["user_emotion_compositions"],
				)
			]
		}


# Switch to inference mode so train-only layers (e.g. dropout), if any, are disabled.
model.eval()
eval_dataset = dataset["test"].map(evolve_bot_representations, batched=True)

Map:   0%|          | 0/958 [00:00<?, ? examples/s]

In [115]:
# Evolved bot representations of the first evaluation example, shown as one tensor.
first_eval_example = eval_dataset[0]
torch.tensor(first_eval_example["bot_emotion_representations"])

tensor([[ 1.1458,  0.5676,  0.2000,  0.2053,  0.6992,  0.5418,  0.3416],
        [ 1.3333,  0.2167,  0.3927,  0.1353,  0.6972,  0.5887,  0.0757],
        [ 1.5209, -0.1343,  0.5853,  0.0652,  0.6952,  0.6357, -0.1903],
        [ 1.7081, -0.4853,  0.7777, -0.0049,  0.6932,  0.6827, -0.4563],
        [ 1.8953, -0.8363,  0.9701, -0.0750,  0.6915,  0.7296, -0.7223]])

In [116]:
# Index of the largest value in each row of the first example's representations.
first_eval_representations = torch.tensor(eval_dataset[0]["bot_emotion_representations"])
first_eval_representations.argmax(dim=1)

tensor([0, 0, 0, 0, 0])

In [117]:
def pick_top_emotions(representation_batch):
	"""Return, for each example in the batch, the argmax along dim 1 of its representations."""
	top_emotions = []
	for representations in representation_batch:
		top_emotions.append(torch.tensor(representations).argmax(dim=1))
	return {"bot_possible_emotion": top_emotions}


# Add a "bot_possible_emotion" column derived from "bot_emotion_representations".
eval_dataset = eval_dataset.map(
	pick_top_emotions,
	input_columns="bot_emotion_representations",
	batched=True,
	num_proc=16,
)

Map (num_proc=16):   0%|          | 0/958 [00:00<?, ? examples/s]

In [118]:
from torch import Tensor

# Concatenate the per-example predicted labels into a single tensor.
predicted_turns = [torch.tensor(turn) for turn in eval_dataset["bot_possible_emotion"]]
eval_predictions: Tensor = torch.cat(predicted_turns)
eval_predictions.shape

torch.Size([3040])

In [119]:
# Concatenate the per-example ground-truth labels into a single tensor.
truth_turns = [torch.tensor(turn) for turn in eval_dataset["bot_emotion"]]
eval_truths: Tensor = torch.cat(truth_turns)
eval_truths.shape

torch.Size([3040])

In [120]:
# Sanity check: predictions and ground truths have the same number of entries.
eval_predictions.size(0) == eval_truths.size(0)

True