In [4]:
import os

import pandas as pd
import requests
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, DistilBertModel

# Pick the compute device: CUDA when PyTorch reports a usable GPU, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("‚úÖ Ïã§Ìñâ Ïû•Ïπò:", device)

# Mini-batch size shared by both DataLoaders (raised from 16 to 32).
BATCH_SIZE = 32

# In-memory memo of Fact Check API results, keyed by the query string.
fact_check_cache = dict()

def check_fact_with_api(query, timeout=10):
    """Look up *query* in the Google Fact Check Tools API, with caching.

    Results are memoised in the module-level ``fact_check_cache`` so each
    distinct query triggers at most one HTTP request per process.

    Args:
        query: Text to search for among fact-checked claims.
        timeout: Seconds to wait for the HTTP request before giving up.
            Without a timeout ``requests`` can block indefinitely, which
            would stall every DataLoader iteration that reaches this call.

    Returns:
        The list stored under the ``"claims"`` key of the JSON response, or
        an empty list when the response has no such key, the status code is
        not 200, the request fails, or the body is not valid JSON.
    """
    try:
        return fact_check_cache[query]
    except KeyError:
        pass

    # SECURITY NOTE(review): an API key is committed in source below. It is
    # kept only as a fallback so existing runs behave the same; set the
    # FACT_CHECK_API_KEY environment variable instead, then rotate and
    # remove the literal.
    api_key = os.environ.get(
        "FACT_CHECK_API_KEY", "AIzaSyDW8TNNxSZG2NXzA3HGK-19PDBp0jjoOu0"
    )
    params = {"query": query, "key": api_key}

    claims = []
    try:
        response = requests.get(
            "https://factchecktools.googleapis.com/v1alpha1/claims:search",
            params=params,
            timeout=timeout,
        )
        if response.status_code == 200:
            claims = response.json().get("claims", [])
    except (requests.RequestException, ValueError):
        # Network failure or malformed JSON: treat it like a non-200 reply
        # (no claims) instead of aborting the caller's training loop.
        claims = []

    # Failures are cached too, matching the original handling of non-200
    # responses, so a failing query is not retried on every epoch.
    fact_check_cache[query] = claims
    return claims

# Load the news dataset from its CSV file and fetch the pretrained
# DistilBERT (uncased) tokenizer used to encode article text.
_dataset_csv = "C:/FakeNewsProject/FakeNews_py/News_Dataset.csv"
_tokenizer_name = 'distilbert-base-uncased'
df = pd.read_csv(_dataset_csv)
tokenizer = AutoTokenizer.from_pretrained(_tokenizer_name)

class NewsDataset(Dataset):
    """Map-style dataset pairing article text with its label.

    Each item is a dict with the tokenized text (``input_ids``,
    ``attention_mask``), a ``fact_check_score`` (the number of claims the
    Fact Check API returned for the text) and the ``label``.
    """

    def __init__(self, texts, labels, tokenizer, max_len=512):
        """Store the indexable texts/labels, the tokenizer and the pad length."""
        self.texts, self.labels = texts, labels
        self.tokenizer, self.max_len = tokenizer, max_len

    def __len__(self):
        """Return the number of texts held by the dataset."""
        return len(self.texts)

    def __getitem__(self, index):
        """Build the model inputs for the sample at *index*."""
        article, target = self.texts[index], self.labels[index]

        # The cached Fact Check lookup is keyed by the full article text;
        # the score is simply how many claims came back.
        # NOTE(review): this can issue an HTTP request per uncached sample
        # during iteration; confirm that is intended.
        claim_count = len(check_fact_with_api(article))

        tokenized = self.tokenizer(
            article,
            truncation=True,
            padding='max_length',
            max_length=self.max_len,
            return_tensors='pt',
        )

        sample = {}
        for field in ('input_ids', 'attention_mask'):
            # squeeze(0) drops the leading dimension of the returned tensor
            # (presumably the size-1 batch axis for a single text).
            sample[field] = tokenized[field].squeeze(0)
        sample['fact_check_score'] = torch.tensor(claim_count, dtype=torch.float)
        sample['label'] = torch.tensor(target, dtype=torch.long)
        return sample

# Hold out 20% of the rows for testing (fixed seed for reproducibility),
# then wrap each side in a NewsDataset and a DataLoader. Only the training
# loader shuffles.
train_texts, test_texts, train_labels, test_labels = train_test_split(
    df['Content'].values,
    df['Label'].values,
    test_size=0.2,
    random_state=42,
)

train_dataset = NewsDataset(texts=train_texts, labels=train_labels, tokenizer=tokenizer)
test_dataset = NewsDataset(texts=test_texts, labels=test_labels, tokenizer=tokenizer)
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# GRU head on top of DistilBERT (chosen because it is faster than an LSTM).
class FakeNewsDetector(nn.Module):
    """DistilBERT encoder, bidirectional GRU and a linear classifier.

    The GRU summarises the encoder's token states into one vector per
    sample; the scalar fact-check score is appended to that vector before
    the final linear layer.

    Args:
        hidden_dim: Hidden size of each GRU direction.
        num_classes: Number of output logits.
    """

    def __init__(self, hidden_dim=128, num_classes=2):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        # 768 is the input feature size the GRU expects from the encoder.
        self.gru = nn.GRU(768, hidden_dim, batch_first=True, bidirectional=True)
        # +1 input feature for the appended fact-check score.
        self.fc = nn.Linear(hidden_dim * 2 + 1, num_classes)
        self.dropout = nn.Dropout(0.3)

    def forward(self, input_ids, attention_mask, fact_check_score):
        """Return class logits for a batch.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Mask passed to the encoder; its row sums are
                used as the true sequence lengths. Assumes right-padding
                (real tokens first) -- TODO confirm for other tokenizers.
            fact_check_score: One scalar per sample (a 1-D tensor).

        Returns:
            The output of the final linear layer, one row per sample.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        # Inputs are padded to a fixed length, so reading the GRU output at
        # index -1 would sample a padding position (and a backward state
        # that has seen a single token). Packing by true length makes the
        # GRU stop at the last real token in each sequence.
        lengths = attention_mask.sum(dim=1).clamp(min=1).to(
            device="cpu", dtype=torch.int64
        )
        packed = nn.utils.rnn.pack_padded_sequence(
            outputs.last_hidden_state,
            lengths,
            batch_first=True,
            enforce_sorted=False,
        )
        _, h_n = self.gru(packed)

        # h_n[-2] / h_n[-1] are the final forward / backward hidden states.
        summary = torch.cat((h_n[-2], h_n[-1]), dim=1)
        out = self.dropout(summary)
        out = torch.cat((out, fact_check_score.unsqueeze(1)), dim=1)
        return self.fc(out)

# Build the classifier and move it to the selected device.
model = FakeNewsDetector()
model = model.to(device)

# Adam over all model parameters (learning rate was increased to 5e-5).
optimizer = optim.Adam(params=model.parameters(), lr=5e-5)

# Loss function used for both training and evaluation.
criterion = nn.CrossEntropyLoss()

def _run_epoch(model, loader, device, optimizer=None):
    """Run one pass over *loader*; train if *optimizer* is given, else evaluate.

    Returns:
        ``(mean_batch_loss, accuracy)``. Both are NaN when the loader yields
        nothing, instead of raising ``ZeroDivisionError``.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    total_loss, correct, total = 0.0, 0, 0

    # Gradients are tracked only while training.
    with torch.set_grad_enabled(training):
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            fact_check_score = batch['fact_check_score'].to(device)
            labels = batch['label'].to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(input_ids, attention_mask, fact_check_score)
            # Uses the module-level `criterion`, as the original loops did.
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            correct += (outputs.argmax(1) == labels).sum().item()
            total += labels.size(0)

    num_batches = len(loader)
    mean_loss = total_loss / num_batches if num_batches else float("nan")
    accuracy = correct / total if total else float("nan")
    return mean_loss, accuracy


def train_model(model, train_loader, test_loader, optimizer, device, epochs=5):
    """Train *model* for *epochs* epochs, evaluating after each one.

    Args:
        model: Module called as ``model(input_ids, attention_mask,
            fact_check_score)`` and returning class logits.
        train_loader: Iterable of batches used for optimisation.
        test_loader: Iterable of batches used for evaluation only.
        optimizer: Optimizer stepped once per training batch.
        device: Device every batch tensor is moved to.
        epochs: Number of train-and-evaluate rounds.

    Returns:
        Four lists with one entry per epoch: train losses, test losses,
        train accuracies and test accuracies. Losses are the mean of the
        per-batch loss values.
    """
    train_losses, test_losses = [], []
    train_accuracies, test_accuracies = [], []

    for epoch in range(epochs):
        # Training pass
        train_loss, train_acc = _run_epoch(model, train_loader, device, optimizer)
        train_losses.append(train_loss)
        train_accuracies.append(train_acc)

        # Validation pass (no optimizer, so no gradient updates)
        test_loss, test_acc = _run_epoch(model, test_loader, device)
        test_losses.append(test_loss)
        test_accuracies.append(test_acc)

        print(f"Epoch {epoch+1}/{epochs}")
        print(f"Train Loss: {train_losses[-1]:.4f}, Train Accuracy: {train_accuracies[-1]:.4f}")
        print(f"Test Loss: {test_losses[-1]:.4f}, Test Accuracy: {test_accuracies[-1]:.4f}")
        print("-" * 50)

    return train_losses, test_losses, train_accuracies, test_accuracies

# Run training for 5 epochs and keep the per-epoch loss/accuracy histories.
(
    train_losses,
    test_losses,
    train_accuracies,
    test_accuracies,
) = train_model(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    optimizer=optimizer,
    device=device,
    epochs=5,
)

‚úÖ Ïã§Ìñâ Ïû•Ïπò: cpu


KeyboardInterrupt: 