In [51]:
import torch
import multiprocessing
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset
from PIL import Image
from PIL import ImageFile
import os
import requests
import json 
from tqdm import tqdm
import time
from torch.autograd import Variable
ImageFile.LOAD_TRUNCATED_IMAGES = True
import copy
import numpy as np
import matplotlib.pyplot as plt

ModuleNotFoundError: No module named 'comic_dataset'

In [29]:
#TOTAL_POSTS = 7400200 # Total number of posts on Danbooru
IMG_PER_BATCH = 200 # Read limit from Danbooru
TAGS = ['action', 'looking_at_another', 'romance', 'sad', 'crying', 'angry', 'scared', 'surprised', 'fighting', 'chase', 'talking', 'couple'] # Tags for our categories
#STOP_ID = 7400200
NUM_TAGS = len(TAGS)
## TEMP ##
TOTAL_POSTS = 3727400
JSON_FILE = 'comic_labels.json'
TRAIN_JSON = 'comic_labels_train.json'
VAL_JSON = 'comic_labels_val.json'
GPU_MODE = 0
MPS_MODE = 1
CUDA_DEVICE = 0
NUM_EPOCHS = 5
BASE_LR = 0.001
DECAY_WEIGHT = 0.1 
EPOCH_DECAY = 30 
BATCH_SIZE = 10

In [46]:
# Creates directory
def create_dir(path):
    if not os.path.exists(path):
        os.makedirs(path)

# Returns the tags from TAGS that are in the tag_string of a post; returns False if none of our TAGS are in tag_string
def check_tags_in_tag_string(tag_string, tags):
    tag_list = []
    for tag in tags:
        if tag in tag_string:
            tag_list.append(tag)
    if len(tag_list) == 0:
        return False
    else:
        return tag_list
    
# Copied from custom_hymenoptera_dataset.py
# Checks for valid image files (size and extension)
def is_valid_image_file(filename, max_pixels=178956970):
    # Check file name extension
    valid_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff']
    if os.path.splitext(filename)[1].lower() not in valid_extensions:
        print(f"Invalid image file extension \"{filename}\". Skipping this file...")
        return False
    
    # Temporarily disable the decompression bomb check
    original_max_image_pixels = Image.MAX_IMAGE_PIXELS
    Image.MAX_IMAGE_PIXELS = None
    
    # Verify that image file is intact and check its size
    try:
        with Image.open(filename) as img:
            img.verify()  # Verify if it's an image
            # Restore the original MAX_IMAGE_PIXELS limit
            Image.MAX_IMAGE_PIXELS = original_max_image_pixels
            
            # Check image size without loading the image into memory
            if img.size[0] * img.size[1] > max_pixels:
                print(f"Image {filename} is too large. Skipping this file...")
                return False
            return True
    except (IOError, SyntaxError) as e:
        print(f"Invalid image file {filename}: {e}")
        # Ensure the MAX_IMAGE_PIXELS limit is restored even if an exception occurs
        Image.MAX_IMAGE_PIXELS = original_max_image_pixels
        return False
    # Ensure the MAX_IMAGE_PIXELS limit is restored in case of any other unexpected exit
    Image.MAX_IMAGE_PIXELS = original_max_image_pixels
    
# Create dir for our comic images
images_dir = '../comic_images'
create_dir(images_dir)

# Dict for image file name and list of tags
image_label_dict = {}

B = IMG_PER_BATCH

In [5]:
progress_bar = tqdm(total=TOTAL_POSTS)

# Loops over all post id's to download images from posts that are tagged "comic" and contain at least one of our TAGS
# Also creates dict of image file name and associated tags
while B <= TOTAL_POSTS:
    url = f'https://danbooru.donmai.us/posts.json?page=b{B}&page=a{B-IMG_PER_BATCH}&limit={IMG_PER_BATCH}'
    response_pages = requests.get(url)

    response_pages_json = response_pages.json()

    for page in response_pages_json:
        if 'file_url' in page and 'tag_string' in page:
            tag_string = page['tag_string']
            id = page['id']

            scene_tags = check_tags_in_tag_string(tag_string, TAGS)
            if 'comic' in page['tag_string'] and scene_tags:
                file_url = page['file_url']

                image_path = f'{id}.jpg'
        
                if not os.path.exists(os.path.join(images_dir, image_path)):
                    response_img = requests.get(file_url)

                    # If post contains relevant tags and has valid image file, save the image with id as name
                    if response_img.status_code == 200:
                        
                        with open(os.path.join(images_dir, image_path), 'wb') as file:
                            file.write(response_img.content)
                        
                # Write values to dict
                image_label_dict[image_path] = scene_tags

    B += IMG_PER_BATCH
    progress_bar.update(200)

progress_bar.close()

print(image_label_dict)

  0%|          | 0/3727400 [00:00<?, ?it/s]

  0%|          | 5000/3727400 [00:11<2:25:41, 425.82it/s]

In [6]:
# Save dict to .json file
with open(JSON_FILE, 'w') as f: 
     json.dump(image_label_dict, f)

In [2]:
import random

# Break data into train and val .json files
with open(JSON_FILE, 'r') as f:
    data = json.load(f)

    total_items = len(data)
    val_part_size = int(0.2*total_items)
    train_part_size = total_items-val_part_size

    keys = list(data.keys())

    random.shuffle(keys)

    val_part_keys = keys[:val_part_size]
    train_part_keys = keys[val_part_size:]

    val_part_dict = {key: data[key] for key in val_part_keys}
    train_part_dict = {key: data[key] for key in train_part_keys}

    with open(TRAIN_JSON, 'w') as t:
        json.dump(train_part_dict, t)

    with open(VAL_JSON, 'w') as w:
        json.dump(val_part_dict, w)

In [None]:
class ComicDataset(Dataset):
    # Hot encode our labels for our targets
    def hot_encode_target(self, tags):
        target = torch.zeros(NUM_TAGS)
        for tag in tags:
            target[TAGS.index(tag)] = 1

        return target

    def __init__(self, images_dir, json_file, transform=None, target_transform=None):
        self.images_dir = images_dir
        self.transform = transform
        self.target_transform = target_transform

        image_label_dict = {}
        class_counts = {}

        for tag in TAGS:
            class_counts[tag] = 0

        with open(json_file, 'r') as f:
            data = json.load(f)

        for key, value in data.items():
            if key in os.listdir(images_dir):
                if is_valid_image_file(os.path.join(self.images_dir, key)):
                    target = self.hot_encode_target(value)

                    if len(target) == NUM_TAGS:
                        image_label_dict[key] = target
                        for v in value:
                            class_counts[v] += 1

                    else:
                        print('Invalid file: ' + key + '. Skipping this file...')

        self.items = list(image_label_dict.items())
        print('Class counts: ', class_counts)

        if (sum(class_counts.values()) > 23000):
            phase = "TRAIN"
        else:
            phase = "VAL"
        print(f"{phase.upper()} SET STATISTICS:")
        total_images = sum(class_counts.values())
        print(f"Total images: {total_images}")
        for class_id, count in class_counts.items():
            print(f"Class {class_id}: {count} images")
        print("\n")

    def __len__(self):
        return len(self.items)
    
    def __getitem__(self, idx):
        img_path = os.path.join(self.images_dir, self.items[idx][0])
        image = Image.open(img_path)
        label = self.items[idx][1]

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        
        return image, label

In [42]:
use_gpu = GPU_MODE
use_mps = MPS_MODE
if use_gpu:
    torch.cuda.set_device(CUDA_DEVICE)

if use_mps:
   mps_device = torch.device("mps")

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

In [47]:
def train_model(model, criterion, optimizer, lr_scheduler, num_epochs=100):
    since = time.time()

    best_model = model
    best_loss = torch.inf

    losses = {'train': [], 'val': []}

    for epoch in range(num_epochs):
        print('-' * 10)
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                optimizer = lr_scheduler(optimizer, epoch)
                model.train()
            else:
                model.eval()
        
            running_loss = 0.0
            running_corrects = 0

            counter = 0

            for data in dset_loaders[phase]:
                inputs, labels = data

                if use_gpu:
                    try:
                        inputs, labels = Variable(inputs.float().cuda()), Variable(labels.long().cuda())

                    except Exception as e:
                        print("ERROR! here are the inputs and labels before we print the full stack trace:")
                        print(inputs, labels)
                        raise e
                    
                elif use_mps:
                   try:
                      inputs, labels = Variable(inputs.float().to(mps_device)), Variable(labels.long().to(mps_device))

                   except Exception as e:
                      print("ERROR! here are the inputs and labels before we print the full stack trace:")
                      print(inputs, labels)
                      raise e
                
                else:
                    inputs, labels = Variable(inputs), Variable(labels)

                
                optimizer.zero_grad()
                outputs = model(inputs)

                loss = criterion(outputs, labels)

                if counter%100 == 0:
                    print('Reached batch iteration', counter)

                counter += 1

                if phase == 'train':
                    loss.backward()
                    optimizer.step()
                try:
                    running_loss += loss.item()
                except:
                    print('unexpected error, could not calculate loss or do a sum.')

            epoch_loss = running_loss / dset_sizes[phase]
            print('{} Loss: {:.4f}'.format(phase, epoch_loss))
            losses[phase].append(epoch_loss)

            # deep copy the model
            if phase == 'val':
                if epoch_loss < best_loss:
                    best_loss = epoch_loss
                    best_model = copy.deepcopy(model)
                    print('new best loss =', best_loss)
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val loss: {:4f}'.format(best_loss))
    print('returning and looping back')

    return best_model, losses

def exp_lr_scheduler(optimizer, epoch, init_lr=BASE_LR, lr_decay_epoch=EPOCH_DECAY):
    """Decay learning rate by a factor of DECAY_WEIGHT every lr_decay_epoch epochs."""
    lr = init_lr * (DECAY_WEIGHT**(epoch // lr_decay_epoch))

    if epoch % lr_decay_epoch == 0:
        print('LR is set to {}'.format(lr))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return optimizer

In [52]:
if __name__ == '__main__':

    try:
        multiprocessing.set_start_method('spawn')
    except RuntimeError as e:
        print("RuntimeError:", e)

    dsets = {}

    dsets['train'] = ComicDataset(images_dir, TRAIN_JSON, data_transforms['train'])
    dsets['val'] = ComicDataset(images_dir, VAL_JSON, data_transforms['val'])

    dset_sizes = {split: len(dsets[split]) for split in ['train', 'val']}

    print('Finished making datasets!')

    dset_loaders = {}
    for split in ['train', 'val']:
        dset_loaders[split] = torch.utils.data.DataLoader(dsets[split], batch_size=BATCH_SIZE, shuffle=True, num_workers=12)

    class ResNet50MultiLabel(nn.Module):
        def __init__(self, num_classes):
            super(ResNet50MultiLabel, self).__init__()
            # Load a pre-trained ResNet-50 model
            self.resnet50 = models.resnet50(pretrained=True)
            
            # Replace the final fully connected layer
            # ResNet-50's fc layer output features is 2048
            self.resnet50.fc = nn.Linear(2048, num_classes)
        
        def forward(self, x):
            return self.resnet50(x)

    model = ResNet50MultiLabel(NUM_TAGS)
    criterion = nn.BCEWithLogitsLoss()

    if use_gpu:
        criterion.cuda()
        model.cuda()

    if use_mps:
        criterion.to(mps_device)
        model.to(mps_device)

    optimizer = optim.Adam(model.parameters(), lr = 0.001)

    model, losses = train_model(model, criterion, optimizer, exp_lr_scheduler, num_epochs=NUM_EPOCHS)

    for split in ['train', 'val']:
        print(split, 'losses by epoch:', losses[split])

    figure = plt.figure(figsize=(8, 8))
    cols, rows = 3, 3
    for i in range(1, cols * rows + 1):
        sample_idx = torch.randint(len(dsets['train']), size=(1,)).item()
        img, label = dsets['train'][sample_idx]

        # Convert the tensor image to numpy
        img = img.numpy().transpose((1, 2, 0))  # Change from (C, H, W) to (H, W, C)
        
        # Undo the normalization
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        img = std * img + mean  # Apply the inverse of the initial normalization
        img = np.clip(img, 0, 1)  # Ensure the values are between 0 and 1

        # Plot the image
        figure.add_subplot(rows, cols, i)
        plt.title(str(label))
        plt.axis("off")
        plt.imshow(img)  # img is now in the correct format for imshow

    plt.show()
    plt.savefig('train_images.png')


    def plot_training_history(losses):
        plt.figure(figsize=(12, 6))

        plt.subplot(1, 2, 2)
        for phase in ['train', 'val']:
            plt.plot(losses[phase], label=f'{phase} loss')
        plt.title('Loss over Epochs')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        plt.show()
        plt.savefig('loss_plot.png')

    plot_training_history(losses)
    torch.save(model.state_dict(), 'fine_tuned_best_model.pt')

----------
Epoch 0/4
----------
LR is set to 0.001


PicklingError: Can't pickle <class '__main__.ComicDataset'>: it's not the same object as __main__.ComicDataset