In [None]:
import os

import wandb

# Start a W&B run in the given project/entity. This call is unconditional:
# the cell does not check for an already-active run.
run = wandb.init(project="further-testing-da", entity="daisyabbott")

# Fetch the dataset artifact through the run and download it locally;
# artifact_dir is the directory the files were written to.
artifact = run.use_artifact("arcslaboratory/Multirun-testing-1K+/larger-perfect-dataset:v0")
artifact_dir = artifact.download()

# Path to the extracted images inside the artifact (kept as a str).
# NOTE(review): the visible cells below do not read dataset_path; they build
# `path` on their own — confirm which directory training should use.
dataset_path = os.path.join(artifact_dir, "data", "largedata")

In [None]:
import matplotlib.pyplot as plt
from fastai.vision.all import *
from fastai.callback.progress import CSVLogger
import torch
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from PIL import Image
import numpy as np
from pathlib import Path

In [None]:
# Training configuration.
VALID_PCT = 0.05  # fraction of images held out for the validation set
NUM_EPOCHS = 3    # epoch count passed to fine_tune / fit_one_cycle

In [None]:
# Dataset root: the "data" folder inside the directory that artifact.download()
# actually returned. This avoids re-deriving W&B's default
# "artifacts/<name>:<version>" layout relative to the kernel's current working
# directory, which breaks if the cwd changes or the artifact is downloaded
# elsewhere. Made absolute so printed/saved paths do not depend on the cwd.
from pathlib import Path
path = Path(artifact_dir).absolute() / "data"

In [None]:
# Sanity check: list the contents of the dataset directory (fastai's Path.ls).
path.ls()

In [None]:
# Collect the image files under the dataset directory
# (get_image_files searches subfolders by default).
files = get_image_files(path)

# Experiment configuration for this run.
model_name = "resnet18"  # key into compared_models
rep = 1                  # repetition index, embedded in every output filename
use_pretraining = True   # also selects fine_tune vs fit_one_cycle when training
# NOTE(review): in the visible cells this flag is only read when building
# file_prefix; nothing passes it to the DataLoaders or learner — confirm the
# images are loaded in the intended colour mode.
rgb_instead_of_gray = True

In [None]:
# Derived: encode model, colour mode and pretraining choice into one
# filename prefix, e.g. "classification-resnet18-rgb-pretrained".
file_prefix = "-".join([
    "classification",
    model_name,
    "rgb" if rgb_instead_of_gray else "gray",
    "pretrained" if use_pretraining else "notpretrained",
])

In [None]:
# Architectures selectable via model_name (looked up when building the learner).
compared_models = dict(resnet18=resnet18)

In [None]:
# Output locations for this configuration and repetition, all placed in the
# dataset directory. `path` is absolute, so these are full paths.
# NOTE(review): model_filename is not used by the visible cells (no export) —
# confirm whether the trained learner should be saved there.
model_filename = path / f"{file_prefix}-{rep}.pkl"
print("Model filename :", model_filename)
log_filename = path / f"{file_prefix}-trainlog-{rep}.csv"
print("Log filename   :", log_filename)
fig_filename_prefix = path / file_prefix

In [None]:
def get_fig_filename(label: str, ext: str, rep: int) -> str:
    """Build and echo the output path for a figure.

    Args:
        label: Short figure tag inserted into the name (e.g. "batch").
        ext: File extension without the dot (e.g. "pdf").
        rep: Repetition index appended after the label.

    Returns:
        "<fig_filename_prefix>-<label>-<rep>.<ext>" as a string, where
        fig_filename_prefix is the module-level prefix defined earlier.
    """
    target = "{}-{}-{}.{}".format(fig_filename_prefix, label, rep, ext)
    print(label, "filename :", target)
    return target

In [None]:
def filename_to_class(filename: str) -> str:
    """Map an image filename to its steering class.

    The angle is taken from the text after the first "_" (up to the next "_",
    if any), truncated at its first "."; a literal "p" in it stands for the
    decimal point, e.g. "img_-1p5.png" -> -1.5.

    Returns:
        "left" for a positive angle, "right" for a negative one, and
        "forward" otherwise (zero, or NaN, which compares false both ways).
    """
    angle_token = filename.split("_")[1]
    angle = float(angle_token.split(".")[0].replace("p", "."))
    # +1, -1 or 0; NaN yields 0 because both comparisons are False.
    sign = (angle > 0) - (angle < 0)
    return {1: "left", -1: "right"}.get(sign, "forward")

In [None]:
# Build the train/validation DataLoaders, labelling each file with
# filename_to_class and holding out VALID_PCT of the images.
# NOTE(review): no seed is passed, so the validation split is presumably
# re-drawn at random on every run — confirm that is intended across reps.
dls = ImageDataLoaders.from_name_func(
    path,
    files,
    filename_to_class,
    valid_pct=VALID_PCT,
)

In [None]:
# Draw a sample training batch, then save it. Without show_batch() nothing is
# drawn beforehand and plt.savefig writes an empty figure.
dls.show_batch()
plt.savefig(get_fig_filename("batch", "pdf", rep))

In [None]:
# CNN learner for the selected architecture: reports accuracy and writes the
# per-epoch training log to log_filename through CSVLogger.
# NOTE(review): newer fastai releases deprecate cnn_learner in favour of
# vision_learner — switch once the installed version is confirmed.
learn = cnn_learner(
    dls,
    compared_models[model_name],
    metrics=accuracy,
    pretrained=use_pretraining,
    cbs=CSVLogger(fname=log_filename),
)

In [None]:
# Show the learner's base directory; CSVLogger resolves its fname against it.
learn.path

In [None]:
# Train for NUM_EPOCHS: a one-cycle schedule from scratch when pretraining is
# off, otherwise fastai's fine_tune on top of the pretrained weights.
if not use_pretraining:
    learn.fit_one_cycle(NUM_EPOCHS)
else:
    learn.fine_tune(NUM_EPOCHS)