In [61]:
import PIL
import os
import tensorflow as tf
import numpy as np
import pandas
from skimage.io import imread
from tensorflow.keras import layers
from tensorflow.python.keras.layers import Dense, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

In [42]:
directory_path = 'data/images/training_data'

In [52]:
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    directory_path,
    image_size=(256, 256),  # Resize images to 256x256
    batch_size=32,  # Number of images to return in each batch
    label_mode='categorical',  # Can be 'int', 'categorical', or 'binary'
    validation_split=0.2,
    subset='training',
    seed=42
)

Found 46 files belonging to 3 classes.
Using 37 files for training.


In [53]:
test_ds = tf.keras.preprocessing.image_dataset_from_directory(
    directory_path,
    image_size=(256, 256),  # Resize images to 256x256
    batch_size=32,  # Number of images to return in each batch
    label_mode='categorical',  # Can be 'int', 'categorical', or 'binary'
    validation_split=0.2,
    subset='validation',
    seed=42
)

Found 46 files belonging to 3 classes.
Using 9 files for validation.


In [54]:
train_ds.class_names

['cringe', 'funny', 'neutral']

In [62]:
resnet_model = Sequential()

pretrained_model = tf.keras.applications.ResNet50(
    include_top=False,
    input_shape=(256,256,3),
    pooling='avg',
    classes=3
)
 
for layer in pretrained_model.layers:
    layer.trainable = False
    
resnet_model.add(pretrained_model)
resnet_model.add(Flatten())
resnet_model.add(Dense(512, activation='relu'))
resnet_model.add(Dense(3, activation='softmax'))

In [63]:
resnet_model.summary()

Model: "sequential_5"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 resnet50 (Functional)       (None, 2048)              23587712  
                                                                 
 module_wrapper (ModuleWrap  (None, 2048)              0         
 per)                                                            
                                                                 
 module_wrapper_1 (ModuleWr  (None, 512)               1049088   
 apper)                                                          
                                                                 
 module_wrapper_2 (ModuleWr  (None, 3)                 1539      
 apper)                                                          
                                                                 
Total params: 24638339 (93.99 MB)
Trainable params: 1050627 (4.01 MB)
Non-trainable params: 23587712 (89.98 MB)
________

In [64]:
resnet_model.compile(
    optimizer=Adam(lr=0.001),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)



In [66]:
history = resnet_model.fit(
    train_ds,
    validation_data=test_ds,
    epochs=10
)

