In [1]:
!pip install torch torchvision pandas pillow tqdm



In [2]:
import os
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
from tqdm import tqdm

In [3]:
!git clone https://github.com/ieee8023/covid-chestxray-dataset.git
data_dir = "covid-chestxray-dataset/images"

Cloning into 'covid-chestxray-dataset'...
remote: Enumerating objects: 3641, done.[K
remote: Total 3641 (delta 0), reused 0 (delta 0), pack-reused 3641 (from 1)[K
Receiving objects: 100% (3641/3641), 632.96 MiB | 37.87 MiB/s, done.
Resolving deltas: 100% (1450/1450), done.
Updating files: 100% (1174/1174), done.


In [6]:


# create list
image_paths = []
captions = []

for root, _, files in os.walk(data_dir):
    for file in files:
        if file.endswith(".png") or file.endswith(".jpg"):
            image_paths.append(os.path.join(root, file))
            captions.append(os.path.splitext(file)[0])  # کپشن = نام فایل بدون پسوند

# first 50 pic for test
image_paths = image_paths[:50]
captions = captions[:50]

# ==============================
# create vocab
# ==============================
all_words = set()
for cap in captions:
    for w in cap.lower().split():
        all_words.add(w)

vocab = {"<pad>":0, "<start>":1, "<end>":2, "<unk>":3}
for i, w in enumerate(all_words, start=4):
    vocab[w] = i
inv_vocab = {v:k for k,v in vocab.items()}

# ==============================
#  Dataset Class
# ==============================
class CovidXrayDataset(Dataset):
    def __init__(self, image_paths, captions, transform=None, vocab=None):
        self.image_paths = image_paths
        self.captions = captions
        self.transform = transform
        self.vocab = vocab

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert("RGB")
        if self.transform:
            img = self.transform(img)
        caption = self.captions[idx]
        caption_idx = [self.vocab.get("<start>")]
        for word in caption.lower().split():
            caption_idx.append(self.vocab.get(word, self.vocab["<unk>"]))
        caption_idx.append(self.vocab.get("<end>"))
        return img, torch.tensor(caption_idx)

# ==============================
# ۵️⃣ Transform و DataLoader
# ==============================
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor()
])

def collate_fn(batch):
    images, captions = zip(*batch)
    images = torch.stack(images, 0)
    captions = nn.utils.rnn.pad_sequence(captions, batch_first=True, padding_value=0)
    lengths = [len(c)-1 for c in captions]  # طول بدون <start>
    return images, captions, lengths

dataset = CovidXrayDataset(image_paths, captions, transform, vocab)
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

# ==============================
# model Encoder + Decoder
# ==============================
class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        for param in resnet.parameters():
            param.requires_grad = False
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)

    def forward(self, images):
        features = self.resnet(images)
        features = features.view(features.size(0), -1)
        features = self.linear(features)
        return features

class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        embeddings = self.embed(captions[:,:-1])
        inputs = torch.cat((features.unsqueeze(1), embeddings),1)
        hiddens, _ = self.lstm(inputs)
        outputs = self.linear(hiddens)
        return outputs

    def sample(self, features, max_len=20):
        sampled_ids = []
        inputs = features.unsqueeze(1)
        states = None
        for _ in range(max_len):
            hiddens, states = self.lstm(inputs, states)
            outputs = self.linear(hiddens.squeeze(1))
            predicted = outputs.argmax(1)
            sampled_ids.append(predicted.item())
            inputs = self.embed(predicted).unsqueeze(1)
        return sampled_ids

# ==============================
# create model
# ==============================
embed_size = 256
hidden_size = 512
vocab_size = len(vocab)

encoder = EncoderCNN(embed_size)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size)

# ==============================
# train with pack_padded_sequence
# ==============================
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam(list(decoder.parameters()) + list(encoder.linear.parameters()), lr=1e-3)

for epoch in range(2):
    for images, captions, lengths in loader:
        outputs = decoder(encoder(images), captions)

        packed_outputs = pack_padded_sequence(outputs, lengths, batch_first=True, enforce_sorted=False)
        packed_targets = pack_padded_sequence(captions[:,1:], lengths, batch_first=True, enforce_sorted=False)

        loss = criterion(packed_outputs.data, packed_targets.data)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1} done, Loss={loss.item():.4f}")

# ==============================
# test
# ==============================
test_img, _ = dataset[0]  # only 2 sample return
with torch.no_grad():
    feature = encoder(test_img.unsqueeze(0))
    sampled_ids = decoder.sample(feature)

sampled_caption = []
for idx in sampled_ids:
    word = inv_vocab.get(idx, "<unk>")
    if word=="<end>":
        break
    sampled_caption.append(word)
print("Generated Caption:", " ".join(sampled_caption))


Epoch 1 done, Loss=2.1619
Epoch 2 done, Loss=2.0084
Generated Caption: 16689_1_6
