In [21]:
import os
import plotly.express as px
import matplotlib.pyplot as plt
import os
import random
from shutil import copyfile, rmtree
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

In [22]:
root_path = '105_classes_pins_dataset/'

# Keep only sub-directories (one per person) and sort them. os.listdir order is
# filesystem-dependent, and ImageFolder numbers its classes in sorted folder
# order, so sorting keeps this list aligned with the model's label indices.
# Stray files (e.g. OS metadata) would otherwise break the per-person counts.
dir_names = sorted(
    name for name in os.listdir(root_path)
    if os.path.isdir(os.path.join(root_path, name))
)
# Folder names presumably look like "pins_<Name>"; keep the part after the last "_".
person_names = [name.split("_")[-1].title() for name in dir_names]
n_individuals = len(person_names)

print(f"Total number of individuals: {n_individuals}\n")
print(f"Name of the individuals : \n\t{person_names}")

Total number of individuals: 105

Name of the individuals : 
	['Adriana Lima', 'Alex Lawther', 'Alexandra Daddario', 'Alvaro Morte', 'Alycia Dabnem Carey', 'Amanda Crew', 'Amber Heard', 'Andy Samberg', 'Anne Hathaway', 'Anthony Mackie', 'Avril Lavigne', 'Barack Obama', 'Barbara Palvin', 'Ben Affleck', 'Bill Gates', 'Bobby Morley', 'Brenton Thwaites', 'Brian J. Smith', 'Brie Larson', 'Camila Mendes', 'Chris Evans', 'Chris Hemsworth', 'Chris Pratt', 'Christian Bale', 'Cristiano Ronaldo', 'Danielle Panabaker', 'Dominic Purcell', 'Dwayne Johnson', 'Eliza Taylor', 'Elizabeth Lail', 'Elizabeth Olsen', 'Ellen Page', 'Elon Musk', 'Emilia Clarke', 'Emma Stone', 'Emma Watson', 'Gal Gadot', 'Grant Gustin', 'Gwyneth Paltrow', 'Henry Cavil', 'Hugh Jackman', 'Inbar Lavi', 'Irina Shayk', 'Jake Mcdorman', 'Jason Momoa', 'Jeff Bezos', 'Jennifer Lawrence', 'Jeremy Renner', 'Jessica Barden', 'Jimmy Fallon', 'Johnny Depp', 'Josh Radnor', 'Katharine Mcphee', 'Katherine Langford', 'Keanu Reeves', 'Kiernen S

In [23]:
# Number of directory entries (images) in each person's folder, in the same
# order as dir_names / person_names. os.path.join works whether or not
# root_path ends with a path separator.
n_images_per_person = [len(os.listdir(os.path.join(root_path, name))) for name in dir_names]
n_images = sum(n_images_per_person)

print(f"Total Number of Images : {n_images}.")

Total Number of Images : 17534.


In [24]:
# One bar per person; the x-axis already names every bar, so the per-person
# colour legend (105 entries) is hidden to keep the chart readable.
fig = px.bar(
    x=person_names,
    y=n_images_per_person,
    color=person_names,
    labels={'x': 'Person', 'y': 'Number of images', 'color': 'Person'},
    title="Distribution of number of images per person",
)
fig.update_layout(showlegend=False)
fig.show()

In [25]:
# Source dataset folder and the three output folders that split_data() fills.
# 'trening' is the on-disk name of the training split (presumably Serbian for
# "training"); keep it unchanged so previously created folders still match.
main_directory, train_directory, test_directory, validate_directory = (
    '105_classes_pins_dataset',
    'trening',
    'test',
    'validate',
)

In [26]:
def _copy_images(image_names, src_dir, dst_dir):
    """Copy ``image_names`` from ``src_dir`` into ``dst_dir``, creating it only if there is something to copy."""
    if not image_names:
        return
    os.makedirs(dst_dir, exist_ok=True)
    for image in image_names:
        copyfile(os.path.join(src_dir, image), os.path.join(dst_dir, image))


def split_data(source, train, test, validate, split_size=0.8,
               val_split_size=0.15, seed=None, extensions=('.jpg',)):
    """Split a one-folder-per-person image dataset into train/validate/test copies.

    The output directories are deleted and recreated first. Then, for every
    sub-directory of ``source``, the matching images are shuffled and copied
    into ``<split>/<person>``. The person folder name is the source folder name
    with a leading ``'pins_'`` prefix removed.

    Per person, the fractions are:
        train    = split_size - val_split_size   (0.65 by default)
        validate = val_split_size                (0.15 by default)
        test     = 1 - split_size                (0.20 by default)
    Cut points are floored, so any rounding remainder goes to the later splits.

    Args:
        source: dataset root containing one sub-directory per person.
        train, test, validate: output directories (wiped if they exist).
        split_size: fraction of each person's images used for train + validate.
        val_split_size: fraction of each person's images used for validation.
        seed: if given, shuffles use a private ``random.Random(seed)`` so the
            split is reproducible. If ``None``, the global ``random`` module is used.
        extensions: file suffixes to include, matched case-insensitively.

    Raises:
        ValueError: unless ``0 <= val_split_size <= split_size <= 1``. This is
            checked before anything is deleted.
    """
    if not 0 <= val_split_size <= split_size <= 1:
        raise ValueError(
            "expected 0 <= val_split_size <= split_size <= 1, got "
            f"val_split_size={val_split_size}, split_size={split_size}"
        )
    if isinstance(extensions, str):
        extensions = (extensions,)
    extensions = tuple(ext.lower() for ext in extensions)
    rng = random if seed is None else random.Random(seed)

    # Start from empty output directories so files from an earlier split never
    # leak into this one. Remove all first, then create all (original order).
    for split_dir in (train, test, validate):
        if os.path.exists(split_dir):
            rmtree(split_dir)
    for split_dir in (train, test, validate):
        os.makedirs(split_dir)

    for actor_dir in sorted(os.listdir(source)):
        print(actor_dir)
        actor_path = os.path.join(source, actor_dir)
        if not os.path.isdir(actor_path):
            continue

        # Sort before shuffling: os.listdir order is filesystem-dependent, so
        # without this a fixed seed would not reproduce the same split.
        images = sorted(
            image for image in os.listdir(actor_path)
            if image.lower().endswith(extensions)
        )
        rng.shuffle(images)

        train_end = int(len(images) * (split_size - val_split_size))
        val_end = int(len(images) * split_size)

        class_name = actor_dir[len('pins_'):] if actor_dir.startswith('pins_') else actor_dir
        _copy_images(images[:train_end], actor_path, os.path.join(train, class_name))
        _copy_images(images[train_end:val_end], actor_path, os.path.join(validate, class_name))
        _copy_images(images[val_end:], actor_path, os.path.join(test, class_name))

In [27]:
# split_data(main_directory, train_directory, test_directory, validate_directory)

In [28]:
def load_images_and_create_dataloader(directory, batch_size=32, shuffle=True, num_workers=0):
    """Create a DataLoader over an ImageFolder-style directory (one sub-folder per class).

    Each image is resized to 224x224, converted to a tensor and normalized with
    the ImageNet per-channel mean/std.

    Args:
        directory: root folder whose sub-folders are the classes.
        batch_size: number of images per batch.
        shuffle: reshuffle every epoch. Keep ``True`` for training; pass
            ``False`` for validation/test loaders so evaluation order is deterministic.
        num_workers: worker processes used to load images (0 = load in the
            main process, the previous behaviour).

    Returns:
        A ``DataLoader`` yielding ``(images, labels)`` batches. The underlying
        ``ImageFolder`` is ``dataloader.dataset``; its ``classes`` attribute
        maps label indices to folder names.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ])

    dataset = ImageFolder(root=directory, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

In [29]:
def plot_random_image(dataloader, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Display the first image of the next batch drawn from ``dataloader``.

    The image is only "random" if the dataloader shuffles. The whole batch is
    loaded to show one image.

    Args:
        dataloader: yields ``(images, labels)`` batches; ``images`` are assumed
            to be CxHxW float tensors (as produced by ``ToTensor``) — confirm
            for other loaders.
        mean, std: per-channel statistics of the ``Normalize`` transform that
            produced the tensors. They are inverted before display. The defaults
            match ``load_images_and_create_dataloader``. Pass ``None`` for either
            to skip de-normalization (e.g. for un-normalized data).
    """
    images, labels = next(iter(dataloader))
    image, label = images[0], labels[0]

    if mean is not None and std is not None:
        # Undo Normalize: x = (pixel - mean) / std  ->  pixel = x * std + mean.
        # Normalized values fall outside [0, 1], which to_pil_image would
        # convert to wrong colours.
        channel_std = image.new_tensor(std).view(-1, 1, 1)
        channel_mean = image.new_tensor(mean).view(-1, 1, 1)
        image = (image * channel_std + channel_mean).clamp(0, 1)

    label_index = int(label)
    title = f'Label: {label_index}'
    # ImageFolder exposes the class (folder) names; other datasets may not.
    classes = getattr(getattr(dataloader, 'dataset', None), 'classes', None)
    if classes is not None:
        title += f' ({classes[label_index]})'

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(transforms.functional.to_pil_image(image))
    ax.set_title(title)
    ax.axis('off')
    plt.show()

In [30]:
def save_model(model, filepath):
    """Save ``model``'s parameters (its ``state_dict``, not the whole module) to ``filepath``.

    Because only the state_dict is stored, reloading requires rebuilding the
    same architecture first (as ``load_model`` does).
    """
    # The notebook's import cell never imports torch itself, so import it here.
    import torch

    torch.save(model.state_dict(), filepath)
    # Message (Serbian): "Model successfully saved to file: <filepath>"
    print("Model je uspešno sačuvan u fajlu: ", filepath)

In [31]:
def load_model(filepath, num_classes=105, map_location='cpu'):
    """Rebuild a ResNet-50 classifier and load a state_dict saved by ``save_model``.

    Args:
        filepath: path to a saved ``state_dict`` for this architecture.
        num_classes: output size of the replacement final layer. It must match
            the checkpoint (105 identities for this dataset).
        map_location: forwarded to ``torch.load``. The default ``'cpu'`` lets a
            checkpoint saved from GPU tensors load on a CPU-only machine. The
            freshly built model's parameters are on the CPU anyway.

    Returns:
        The model with the loaded weights, in training mode (the ``nn.Module``
        default). Call ``.eval()`` before inference.
    """
    # The notebook's import cell never imports torch, nn or models, so import them here.
    import torch
    from torch import nn
    from torchvision import models

    # Untrained backbone: every parameter is overwritten by the checkpoint.
    # Calling resnet50() without arguments avoids the deprecated `pretrained=` keyword.
    model = models.resnet50()
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
    # (newer torch versions accept weights_only=True to restrict this).
    state_dict = torch.load(filepath, map_location=map_location)
    model.load_state_dict(state_dict)

    return model