# Basic CNN model


In [1]:
# NOTE: I didn't have much time so threw this together in pytorch, we can refactor if necessary
import os
from datetime import datetime

from PIL import Image
import numpy as np
import pyarrow as pa
from pyarrow import parquet

import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from torch import nn
from torchinfo import summary

from sklearn.metrics import classification_report

photo_directory = "data/clean/reduced_photos"

device = (
    torch.accelerator.current_accelerator().type
    if torch.accelerator.is_available()
    else "cpu"
)

In [None]:
# NOTE: could be moved to clean script
# Read the cleaned business and photo tables and join them on "business_id".
b = parquet.read_table("data/clean/reduced_business_details.parquet")
p = parquet.read_table("data/clean/reduced_photo_details.parquet")
photos_scores = p.join(b, "business_id").combine_chunks()


def _star_bucket(r):
    """Map a star rating to a class id: <= 3 -> 0.0, (3, 4] -> 1.0, otherwise 2.0."""
    if r <= 4:
        return 0.0 if r <= 3 else 1.0
    return 2.0


# Insert the bucketed rating as the first column of the joined table.
_stars = photos_scores.select(["stars"]).to_pandas().stars
photos_scores = photos_scores.add_column(
    0, "star_category", pa.array(_stars.apply(_star_bucket))
)

In [None]:
# TODO: image normalization - I think we need a real transform at some point?
class ImageData(Dataset):
    """Dataset of photos read from ``photo_directory`` with one-hot labels.

    Args:
        labels: Indexable collection of class ids, aligned with ``filenames``.
        filenames: Indexable collection of photo ids; item ``i`` is read from
            ``f"{photo_directory}/{filenames[i]}.jpg"``.
        width: Width in pixels each photo is resized to.
        height: Height in pixels each photo is resized to.
        transform: Optional callable applied to the resized RGB PIL image.
            Its result must still convert to a channel-last array via
            ``np.asarray``.
        num_classes: Width of the one-hot label vector.

    Each item is a pair ``(image, label)`` of tensors on ``device``: a
    float32 image scaled by 1/255 (shape ``(3, height, width)`` when no
    transform is used) and an int64 one-hot label of length ``num_classes``.
    """

    def __init__(
        self,
        labels,
        filenames,
        width=256,
        height=256,
        transform=None,
        num_classes=3,
    ) -> None:
        super().__init__()

        self.width = width
        self.height = height
        self.transform = transform
        self.labels = labels
        self.filenames = filenames
        self.num_classes = num_classes

    def __getitem__(self, index):
        """Load, resize and scale photo ``index``; return ``(image, label)``."""
        path = f"{photo_directory}/{self.filenames[index]}.jpg"
        # Image.open is lazy; resize() forces the pixel data to be read, so
        # the file can be closed as soon as the resized copy exists.
        with Image.open(path) as source:
            img = source.resize((self.width, self.height)).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        # A PIL image converts to a (height, width, channel) array; move the
        # channel axis to the front, which is the layout Conv2d consumes.
        pixels = np.asarray(img).transpose(-1, 0, 1) / 255
        img = torch.from_numpy(pixels).to(torch.float32).to(device)

        class_index = torch.from_numpy(np.asarray(self.labels[index]))
        label = nn.functional.one_hot(
            class_index.to(torch.long).to(device),
            num_classes=self.num_classes,
        )
        return img, label

    def __len__(self):
        """Return the number of labelled photos."""
        return len(self.labels)


# Labels are the bucketed star ratings; filenames are the photo ids.
yelp_photos = ImageData(
    labels=photos_scores.select(["star_category"]).to_pandas()["star_category"],
    filenames=photos_scores.select(["photo_id"]).to_pandas()["photo_id"],
)
# NOTE: change batch size for actual learning

# Random 50% / 30% / 20% split into training, validation and test subsets.
train, val, test = torch.utils.data.random_split(yelp_photos, [0.5, 0.3, 0.2])
training_loader = DataLoader(dataset=train, shuffle=True, batch_size=30)
# The test and validation loaders each yield their whole subset as one batch.
test_loader = DataLoader(dataset=test, batch_size=len(test))
validation_loader = DataLoader(dataset=val, batch_size=len(val))

class CNN(nn.Module):
    """Four conv/ReLU/max-pool stages followed by a five-layer MLP head.

    The first linear layer takes 19008 features, which is 132 channels of
    12 x 12 activations -- the output of the convolutional stack for a
    3 x 256 x 256 input. ``forward`` returns raw logits for 3 classes; the
    ``softmax`` attribute is created but not applied in ``forward``.
    """

    def __init__(self):
        super().__init__()

        # Channel counts going into and out of each 5x5 convolution.
        conv_channels = (3, 6, 16, 64, 132)
        conv_layers = []
        for c_in, c_out in zip(conv_channels, conv_channels[1:]):
            conv_layers += [
                nn.Conv2d(c_in, c_out, 5),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            ]
        self.convolutional = nn.Sequential(*conv_layers)

        # Feature widths of the fully connected head.
        dense_widths = (19008, 1500, 120, 84, 10, 3)
        dense_layers = []
        for w_in, w_out in zip(dense_widths, dense_widths[1:]):
            dense_layers += [nn.Linear(w_in, w_out), nn.ReLU()]
        dense_layers.pop()  # the final layer emits logits, so no ReLU after it
        self.lin = nn.Sequential(*dense_layers)

        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        features = self.convolutional(x)
        flattened = torch.flatten(features, 1)
        return self.lin(flattened)


# Counters shared with the epoch loop.
epoch_number = 1
best_vloss = 10000  # starting threshold; lower validation losses are checkpointed

basic_cnn = CNN().to(device)
loss_fn = torch.nn.BCEWithLogitsLoss()
opt = torch.optim.Adam(basic_cnn.parameters(), lr=0.001)

# Print the model summary, then run a forward pass on a single image.
print(summary(basic_cnn))
basic_cnn(yelp_photos[1][0].unsqueeze(0))

In [None]:
# TODO: export model WITHIN LOOP
# One timestamp per run, used to name the TensorBoard log directory.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
writer = SummaryWriter(f"logs/basic_cnn_{timestamp}")

def train_one_epoch(epoch_index, m: nn.Module, loader, o, lf, writer=writer):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        epoch_index: Index of the current epoch (currently unused).
        m: Model to train; updated in place.
        loader: Iterable of ``(inputs, labels)`` batches.
        o: Optimizer that owns ``m``'s parameters.
        lf: Loss function called as ``lf(outputs, labels)``.
        writer: TensorBoard writer (currently unused).

    Returns:
        ``(m, mean_loss)`` where ``mean_loss`` is the sum of the per-batch
        losses divided by the number of batches.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    running_loss = 0.0
    num_batches = 0

    for inputs, labels in loader:
        inputs = inputs.to(device).to(torch.float32)
        labels = labels.to(device).to(torch.float32)

        # squeeze() removes every size-1 dimension, including the batch
        # dimension of a one-sample batch; put that dimension back so the
        # target keeps the same rank as the model output.
        batch_size = labels.shape[0]
        labels = labels.squeeze()
        if batch_size == 1:
            labels = labels.unsqueeze(0)

        o.zero_grad()
        outputs = m(inputs)
        loss = lf(outputs, labels)
        loss.backward()
        # Step the optimizer that was passed in, not a module-level one.
        o.step()

        running_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("loader yielded no batches; cannot compute a mean loss")

    return m, running_loss / num_batches

# Train for 31 epochs; after each one, score the first validation batch and
# checkpoint the weights whenever the validation loss improves.
for epoch in range(31):
    print(f"EPOCH {epoch_number}")
    basic_cnn.train(True)
    basic_cnn, tloss = train_one_epoch(epoch, basic_cnn, training_loader, opt, loss_fn, writer)

    basic_cnn.eval()

    # TODO: add val loss check
    with torch.no_grad():
        # Only the first batch is scored; validation_loader is built with
        # batch_size=len(val), so that batch is the whole validation subset.
        inputs, labels = next(iter(validation_loader))
        inputs = inputs.to(device).to(torch.float32)
        labels = labels.to(device).to(torch.float32)

        # squeeze() would also drop the batch dimension of a one-sample
        # batch; restore it so the target shape matches the model output.
        batch_size = labels.shape[0]
        labels = labels.squeeze()
        if batch_size == 1:
            labels = labels.unsqueeze(0)

        outputs = basic_cnn(inputs)
        vloss = loss_fn(outputs, labels).item()

    print(f"LOSS train {round(tloss, 5)}, validation {round(vloss, 5)}")
    # Log at epoch_number so the TensorBoard step matches the printed epoch
    # and the epoch encoded in the checkpoint file name.
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : tloss, 'Validation' : vloss },
                    epoch_number)
    writer.flush()
    if vloss < best_vloss:
        best_vloss = vloss
        model_path = f"models/basic_cnn/{epoch_number}_{timestamp}_{round(vloss, 3)}"
        # torch.save does not create missing parent directories.
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        torch.save(basic_cnn.state_dict(), model_path)

    epoch_number += 1

In [None]:
# TODO: sklearn for scoring
# Checkpoint file names end in "_<validation loss>"; pick the file with the
# lowest loss by parsing that suffix numerically rather than by substring
# match (a substring match on "0.5" would also accept "0.53").
# NOTE(review): every file in the directory is considered, so a checkpoint
# from an earlier run can be selected -- confirm that is intended.
model_dir = "models/basic_cnn"
saved_models = os.listdir(model_dir)
best_model = min(saved_models, key=lambda name: float(name.rsplit("_", 1)[-1]))

# Evaluate the best checkpoint rather than whatever weights the last epoch
# left in memory.
basic_cnn.load_state_dict(
    torch.load(os.path.join(model_dir, best_model), map_location=device)
)
basic_cnn.eval()

with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = basic_cnn(inputs.to(device).to(torch.float32))
        # Labels are one-hot, so argmax recovers the class index.
        print(str(classification_report(labels.argmax(dim=1).cpu().detach().numpy(),
                                        outputs.argmax(dim=1).cpu().detach().numpy())))