## Paths


In [1]:
import os

# Cache folder location, expressed relative to the kernel's working directory.
_CACHE_SUBPATH = "../.cache/"

# Absolute path of the directory the kernel is currently running in.
CUR_DIR = os.path.abspath(os.curdir)

# Cache directory; the ".." segment is kept verbatim (the joined path is not
# normalized).
CACHE_PATH = os.path.join(CUR_DIR, _CACHE_SUBPATH)

## Load cache


In [2]:
import pickle

# Load the object stored in "final.pkl" under CACHE_PATH into `data`.
# The context manager closes the file handle even if unpickling raises,
# which the manual open()/close() pair did not guarantee.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load cache files that were generated by a trusted source.
with open(os.path.join(CACHE_PATH, "final.pkl"), "rb") as cache:
    data = pickle.load(cache)

## Initialize


In [3]:
import sys

# Bind `torch` explicitly: `import torch.optim as optim` only binds `optim`,
# so the `torch.device(...)` / `torch.cuda` calls below would otherwise depend
# on one of the star imports happening to export `torch`.
import torch
import torch.optim as optim

# Project python files are imported from the parent directory.
sys.path.insert(0, "..")
from module import *
from vocabulary import Vocabulary
from datasets import TestingDataset
from metrics import *
from transforms import TextTransform

# Kept after the star imports on purpose so these explicit names always win.
import torch.nn as nn
from torch.utils.data import DataLoader


# Alternative text + image ensemble, currently disabled:
# text_model = LSTMClassifier(
#     vocab_size=100000, embedding_dim=10, hidden_dim=128, num_classes=2
# )

# # models
# image_model = Resnet18(num_classes=2)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model = Ensemble(text_model=text_model, image_model=image_model, num_classes=2)

model = UserEmbedding(
    max_num_user=100, embedding_dim=300, num_classes=2, dropout_rate=0.5
)
model = model.to(device)

# optimizer and loss
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# metrics
total_acc = SimplePrequentialMetric()
total_loss = SimplePrequentialMetric()

# text_vocabulary: unknown tokens fall back to the index of "<unk>".
vc = Vocabulary(specials=["<pad>", "<unk>"])
vc.set_default_idx(vc["<unk>"])

# dataset
dataset = TestingDataset(data)
# NOTE(review): the loader iterates the raw `data` object, not the `dataset`
# wrapper created on the previous line, which is therefore unused here.
# Confirm whether `DataLoader(dataset, ...)` was intended.
dataloader = DataLoader(data, batch_size=128)


# tokenizer
def tokenizer(x):
    """Split every string in `x` on whitespace and return the list of token lists."""
    return [_.split() for _ in x]


# transform
text_transform = TextTransform(max_length=20)

In [15]:
import datetime

from tqdm.auto import tqdm

# import wandb

# wandb.init(project="odl")

# user_vocabulary
uvc = Vocabulary(specials=["<unk>"])
# The fallback index must come from `uvc` itself. The text vocabulary `vc` has
# "<pad>" before "<unk>", so its "<unk>" index presumably differs and would
# point at a real user entry here.
uvc.set_default_idx(uvc["<unk>"])

# Wrap the loader itself (not the enumerate object) so tqdm can read its
# length and show real progress instead of an open-ended counter.
for idx, batch in enumerate(tqdm(dataloader)):
    # * Text
    texts = tokenizer(batch["post_message"])

    # * User name
    user_name = batch["user_name"]

    # * Images
    images = batch["image"]

    # * Metadata
    num_like_post = batch["num_like_post"]
    num_comment_post = batch["num_comment_post"]
    num_share_post = batch["num_share_post"]
    raw_length = batch["raw_length"]
    timestamp_post = batch["timestamp_post"]
    # Convert each timestamp once and derive both time features from it.
    # NOTE(review): fromtimestamp() interprets values in the machine's local
    # timezone — confirm this is the intended reference.
    posted_at = [datetime.datetime.fromtimestamp(x.item()) for x in timestamp_post]
    hours = torch.tensor([moment.hour for moment in posted_at], dtype=torch.float)
    weekdays = torch.tensor(
        [moment.weekday() for moment in posted_at], dtype=torch.float
    )

    # One row per sample, one column per metadata feature.
    metadata = torch.stack(
        [
            num_like_post,
            num_comment_post,
            num_share_post,
            raw_length,
            weekdays,
            hours,
        ],
        dim=1,
    )

    # * Target
    labels = batch["label"]

    # Register this batch's users, then inspect the resulting mapping.
    uvc.append_tokens(user_name)
    print(uvc)
    for user in user_name:
        # Look users up in the user vocabulary, not the text vocabulary.
        print(uvc[user])
    print(user_name)
    # Debugging stop: only the first batch is inspected for now.
    break
    # # Transform the texts
    # vc.append_from_iterator(texts)

    # # token -> index -> padding -> tensor -> batch
    # texts = torch.stack([text_transform(vc.get_idxs(tokens)) for tokens in texts])

    # texts = texts.to(device)
    # images = images.to(device, dtype=torch.float)
    # labels = labels.to(device)

    # loss, outputs = train_one(
    #     model=model,
    #     optimizer=optimizer,
    #     criterion=criterion,
    #     inputs=(texts, images),
    #     targets=labels,
    # )

    # acc = accuracy(outputs, labels)
    # total_acc.update(acc)
    # total_loss.update(loss)
    # wandb.log({"acc": total_acc.values, "loss": total_loss.values})

0it [00:00, ?it/s]

{'<unk>': 0, '7c14dfd9cc4c03990ed7343651a06c85': 1, '095f07a3b83a6ea4b7a8fb0ff2721e6f': 2, '9678eb9c71190fc0b57d9266ae7ff5bc': 3, 'e17843f01e3eb99b2df1db64e14a3794': 4, '6039201545b96b1170abb2c79cb28c9d': 5, 'bd612a8c9a208033cd7344cba730800a': 6, '6d46f2b04842bb2221e493968cd05ca4': 7, '0adfdeb881147078b0c50354193d67dd': 8, '6e7401c76ba64a20ea93313a1e571598': 9, '158b4938e2e34f76f8b34c52a02bc663': 10, 'c543472dc0632612a27d6feab784d462': 11, '2007e8acfdb13b766efdc9e1b46fb4f7': 12, '344dc37937543bbc0dcb5c023f038cf9': 13, 'fa91f55a3dafc25e39ff1687d31423f2': 14, 'df49d91e8769e84fcc17bf38cdddec70': 15, 'fc3f2b1ca11ee3c923e005ef9d8eb212': 16, '57a2f68d3cd56978c1d3a52ff8541a96': 17, '57bd64a04c633693bcf24cdcecc7ec8a': 18, '82636849e8933b59ebab10bc26cc59fc': 19, 'fc4107af786edb367c86b5e8333ebc13': 20, '6cf14fe3f045142ad0ad38c58e26377b': 21, '16244723c53210008d58c2f9f2b14d29': 22, '4ede4a4281b6a26ff1db5f3e589b4383': 23, '88d67d0cdf3664a85ba5c42bd704618e': 24, '68d174bb739151ea2d308cc8b3732aa4': 