In [None]:
import torch
import pandas as pd
import easyocr
from PIL import Image
from torchvision import transforms
import clip
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import DataLoader
from torch.optim import Adam
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
import torch.nn as nn
import torch.nn.functional as F

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [3]:
df = pd.read_json('vimmsd-warmup.json')
df = df.T
df.head()

label_mapping = {
    'not-sarcasm': 0,
    'image-sarcasm': 1,
    'text-sarcasm': 2,
    'multi-sarcasm': 3
}

df['label'] = df['label'].map(label_mapping)
df.head()

Unnamed: 0,image,caption,label
464,724743746f3fe695cd93cab67abf47f31348dd46e1d6e8...,Biển miền Trung nước đẹp nhỉ,3
7413,92d5d63ece4471fa20fda5a504b841f17eaee8172de711...,Chắc là nắc cụt rồi\n#phetphaikhong,3
3808,abadbf508db12242d4f00f69ac690305e91dc5d8ad0c07...,Nhiều khi ta muốn ta được thiếu nợ\nĐể khi đi ...,0
5816,84a61e90daadb2297888d685299e25a00a03a91515059b...,"Phi công này 1 người lái thôi, ai đụng vào là ...",3
1632,cf50dca40e9196eb443a4b33db60c17c5ad2da69726aab...,Ủy ban Nhân dân thành phố Đà Nẵng vừa có văn b...,0


Lấy text từ image sử dụng easyOCR (language: vietnam)

In [4]:
reader = easyocr.Reader(['vi'], gpu=True)


In [5]:
tokenizer = AutoTokenizer.from_pretrained('vinai/phobert-base')
phobert_model = AutoModel.from_pretrained('vinai/phobert-base')

clip_model, preprocess = clip.load("ViT-B/32", device=device)

In [6]:
def extract_text_from_image(image_path, phase):
    image_path = "./image/" + phase + '/' + image_path
    result = reader.readtext(image_path)
    text = ' '.join([res[1] for res in result])
    return text

def encode_text(text):
    inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=128)
    with torch.no_grad():
        outputs = phobert_model(**inputs)
    return outputs.last_hidden_state.mean(dim=1)

def concat_tensors(row):
    return torch.cat((row['encoded_caption'], row['encoded_extracted_text']), dim=1)

def encode_image(image_path, phase):
    image_path = "./image/" + phase + "/" + image_path
    image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)  # Tiền xử lý ảnh
    with torch.no_grad():
        image_features = clip_model.encode_image(image)
    return image_features

In [7]:
df['encoded_caption'] = df['caption'].apply(encode_text)
df['extracted_text'] = df['image'].apply(lambda x: extract_text_from_image(x, 'warmup'))
df['encoded_extracted_text'] = df['extracted_text'].apply(encode_text)
df['combined_text'] = df.apply(concat_tensors, axis=1)
df['encoded_image'] = df['image'].apply(lambda x: encode_image(x, 'warmup'))

df.head()

  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


Unnamed: 0,image,caption,label,encoded_caption,extracted_text,encoded_extracted_text,combined_text,encoded_image
464,724743746f3fe695cd93cab67abf47f31348dd46e1d6e8...,Biển miền Trung nước đẹp nhỉ,3,"[[tensor(-0.3261), tensor(0.1561), tensor(-0.0...","VỢ CHỔNG THUỶ TIÊN, CÔNG VINH ĐI CHOI PHÚ QUỐC...","[[tensor(-0.2914), tensor(0.2022), tensor(-0.1...","[[tensor(-0.3261), tensor(0.1561), tensor(-0.0...","[[tensor(-0.2019, device='cuda:0', dtype=torch..."
7413,92d5d63ece4471fa20fda5a504b841f17eaee8172de711...,Chắc là nắc cụt rồi\n#phetphaikhong,3,"[[tensor(-0.1268), tensor(-0.0320), tensor(-0....",Jz mấy má? 29 phút trước Học lý 12 không hiểu ...,"[[tensor(-0.1801), tensor(0.1206), tensor(-0.0...","[[tensor(-0.1268), tensor(-0.0320), tensor(-0....","[[tensor(-0.1495, device='cuda:0', dtype=torch..."
3808,abadbf508db12242d4f00f69ac690305e91dc5d8ad0c07...,Nhiều khi ta muốn ta được thiếu nợ\nĐể khi đi ...,0,"[[tensor(-0.1112), tensor(0.0278), tensor(-0.0...",,"[[tensor(0.1059), tensor(0.6111), tensor(-0.26...","[[tensor(-0.1112), tensor(0.0278), tensor(-0.0...","[[tensor(-0.0428, device='cuda:0', dtype=torch..."
5816,84a61e90daadb2297888d685299e25a00a03a91515059b...,"Phi công này 1 người lái thôi, ai đụng vào là ...",3,"[[tensor(0.0312), tensor(0.0544), tensor(-0.21...",ĐÃ BlẾT ĐUỢC LÍ DO LỆ KWEEN KHÔNG UA TRANG PARIS,"[[tensor(-0.0723), tensor(0.1501), tensor(-0.2...","[[tensor(0.0312), tensor(0.0544), tensor(-0.21...","[[tensor(-0.0614, device='cuda:0', dtype=torch..."
1632,cf50dca40e9196eb443a4b33db60c17c5ad2da69726aab...,Ủy ban Nhân dân thành phố Đà Nẵng vừa có văn b...,0,"[[tensor(-0.0825), tensor(0.1394), tensor(-0.0...",Đà Nẵng: Cẩu Rông dừng phun cầu sông Hàn không...,"[[tensor(-0.0106), tensor(0.2360), tensor(-0.2...","[[tensor(-0.0825), tensor(0.1394), tensor(-0.0...","[[tensor(-0.1697, device='cuda:0', dtype=torch..."


In [8]:
class ImageTransform():
    def __init__(self, resize, mean, std):
        self.data_transform = transforms.Compose([
            transforms.RandomResizedCrop(resize, scale=(0.5, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
    def __call__(self, img):
        return self.data_transform(img)

In [9]:
resize = 512
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

image_transform = ImageTransform(resize, mean, std)

In [10]:
class SarcasmDataset(Dataset): 
    def __init__(self, dataframe):
        self.dataframe = dataframe
    
    def __len__(self):
        return len(self.dataframe)
    
    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        
        ret = {
            'image_features': row['encoded_image'].float(),
            'caption_features': row['encoded_caption'].float(),
            'extracted_text_features': row['encoded_extracted_text'],
            'label': row['label']
        }
        
        return ret

# Tạo dataset

In [11]:
df.head()

Unnamed: 0,image,caption,label,encoded_caption,extracted_text,encoded_extracted_text,combined_text,encoded_image
464,724743746f3fe695cd93cab67abf47f31348dd46e1d6e8...,Biển miền Trung nước đẹp nhỉ,3,"[[tensor(-0.3261), tensor(0.1561), tensor(-0.0...","VỢ CHỔNG THUỶ TIÊN, CÔNG VINH ĐI CHOI PHÚ QUỐC...","[[tensor(-0.2914), tensor(0.2022), tensor(-0.1...","[[tensor(-0.3261), tensor(0.1561), tensor(-0.0...","[[tensor(-0.2019, device='cuda:0', dtype=torch..."
7413,92d5d63ece4471fa20fda5a504b841f17eaee8172de711...,Chắc là nắc cụt rồi\n#phetphaikhong,3,"[[tensor(-0.1268), tensor(-0.0320), tensor(-0....",Jz mấy má? 29 phút trước Học lý 12 không hiểu ...,"[[tensor(-0.1801), tensor(0.1206), tensor(-0.0...","[[tensor(-0.1268), tensor(-0.0320), tensor(-0....","[[tensor(-0.1495, device='cuda:0', dtype=torch..."
3808,abadbf508db12242d4f00f69ac690305e91dc5d8ad0c07...,Nhiều khi ta muốn ta được thiếu nợ\nĐể khi đi ...,0,"[[tensor(-0.1112), tensor(0.0278), tensor(-0.0...",,"[[tensor(0.1059), tensor(0.6111), tensor(-0.26...","[[tensor(-0.1112), tensor(0.0278), tensor(-0.0...","[[tensor(-0.0428, device='cuda:0', dtype=torch..."
5816,84a61e90daadb2297888d685299e25a00a03a91515059b...,"Phi công này 1 người lái thôi, ai đụng vào là ...",3,"[[tensor(0.0312), tensor(0.0544), tensor(-0.21...",ĐÃ BlẾT ĐUỢC LÍ DO LỆ KWEEN KHÔNG UA TRANG PARIS,"[[tensor(-0.0723), tensor(0.1501), tensor(-0.2...","[[tensor(0.0312), tensor(0.0544), tensor(-0.21...","[[tensor(-0.0614, device='cuda:0', dtype=torch..."
1632,cf50dca40e9196eb443a4b33db60c17c5ad2da69726aab...,Ủy ban Nhân dân thành phố Đà Nẵng vừa có văn b...,0,"[[tensor(-0.0825), tensor(0.1394), tensor(-0.0...",Đà Nẵng: Cẩu Rông dừng phun cầu sông Hàn không...,"[[tensor(-0.0106), tensor(0.2360), tensor(-0.2...","[[tensor(-0.0825), tensor(0.1394), tensor(-0.0...","[[tensor(-0.1697, device='cuda:0', dtype=torch..."


In [12]:
# Tạo dataset
train_dataset = SarcasmDataset(df)

# Tạo Dataloader

In [13]:
batch_size = 32
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Networking

In [23]:
class SarcasmClassifier(nn.Module):
    def __init__(self):
        super(SarcasmClassifier, self).__init__()
        self.fc_image = nn.Linear(512, 512)
        self.fc_caption = nn.Linear(768, 512)
        self.fc1 = nn.Linear(512 * 2, 256)  # Combine both inputs
        self.fc2 = nn.Linear(256, 4)  # 4 classification labels
        
        # Additional layers for better performance
        self.dropout = nn.Dropout(0.5)
        self.batch_norm1 = nn.BatchNorm1d(512)
        self.batch_norm2 = nn.BatchNorm1d(256)

    def forward(self, image_features, caption_features):
        image_features = F.relu(self.fc_image(image_features))
        image_features = self.batch_norm1(image_features)
        
        caption_features = F.relu(self.fc_caption(caption_features))
        caption_features = self.batch_norm1(caption_features)
        
        combined_features = torch.cat((image_features, caption_features), dim=1)
        combined_features = F.relu(self.fc1(combined_features))
        combined_features = self.batch_norm2(combined_features)
        combined_features = self.dropout(combined_features)
        
        output = self.fc2(combined_features)
        return output

# Training

In [15]:
# Khởi tạo mô hình, loss function và optimizer
model = SarcasmClassifier().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.001)

num_epochs = 5
# Huấn luyện mô hình
for epoch in range(num_epochs):  # Huấn luyện trong 5 epoch
    model.train()
    epoch_loss = 0.0

    for batch in tqdm(train_dataloader):
        image_features = batch['image_features'].to(device)
        caption_features = batch['caption_features'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()
        
        outputs = model(image_features, caption_features)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f"Epoch [{epoch+1}/5], Loss: {epoch_loss / len(train_dataloader)}")

100%|██████████| 4/4 [00:00<00:00, 64.35it/s]


Epoch [1/5], Loss: 1.1972186863422394


100%|██████████| 4/4 [00:00<00:00, 179.34it/s]


Epoch [2/5], Loss: 0.7683773040771484


100%|██████████| 4/4 [00:00<00:00, 160.29it/s]


Epoch [3/5], Loss: 0.613918624818325


100%|██████████| 4/4 [00:00<00:00, 186.72it/s]


Epoch [4/5], Loss: 0.47284547984600067


100%|██████████| 4/4 [00:00<00:00, 181.77it/s]

Epoch [5/5], Loss: 0.3723446913063526





# Evaluate

In [16]:
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for batch in train_dataloader:
        image_features = batch['image_features'].to(device)
        caption_features = batch['caption_features'].to(device)
        labels = batch['label'].to(device)

        outputs = model(image_features, caption_features)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy: {100 * correct / total:.2f}%')

Accuracy: 94.06%


In [17]:
torch.save(model.state_dict(), 'sarcasm_classifier.pth')

In [18]:
model = SarcasmClassifier().to(device)
model.load_state_dict(torch.load('sarcasm_classifier.pth'))
model.eval()

  model.load_state_dict(torch.load('sarcasm_classifier.pth'))


SarcasmClassifier(
  (fc_image): Linear(in_features=512, out_features=512, bias=True)
  (fc_caption): Linear(in_features=768, out_features=512, bias=True)
  (fc1): Linear(in_features=1024, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=4, bias=True)
)

In [19]:
test_df = pd.read_json('vimmsd-warmup.json')

test_df = test_df.T

test_df['encoded_caption'] = test_df['caption'].apply(encode_text)
test_df['extracted_text'] = test_df['image'].apply(lambda x: extract_text_from_image(x, 'warmup'))
test_df['encoded_extracted_text'] = test_df['extracted_text'].apply(encode_text)
test_df['combined_text'] = test_df.apply(concat_tensors, axis=1)
test_df['encoded_image'] = test_df['image'].apply(lambda x: encode_image(x, 'warmup'))

test_dataset = SarcasmDataset(test_df)
test_dataloader = DataLoader(test_dataset, batch_size=32)

print(len(test_df))

101


In [20]:
label_map = {0: "not-sarcasm", 1: "image-sarcasm", 2: "text-sarcasm", 3: "multi-sarcasm"}
predictions = []
with torch.no_grad():
    for batch in tqdm(test_dataloader):
        image_features = batch['image_features'].to(device)
        caption_features = batch['caption_features'].to(device)
        
        outputs = model(image_features, caption_features)
        _, predicted = torch.max(outputs, 1)
        predictions.extend(predicted.cpu().numpy())
        
predictions = [label_map[p] for p in predictions]



100%|██████████| 4/4 [00:00<00:00, 196.66it/s]
