In [1]:
import wandb

# Initialize W&B run (if not already initialized)
run = wandb.init(
    project="first-testing-refactored",
    entity="daisyabbott",
    notes="A set of small/useless datasets for testing.",
    job_type="dataset-upload"
)
# Load the dataset artifact
artifact = run.use_artifact("arcslaboratory/Multirun-testing-1K+/larger-perfect-dataset:v0")
artifact_dir = artifact.download()

# Update the dataset path
dataset_path = artifact_dir + "/data/largedata"  # Path to the extracted images from the artifact

[34m[1mwandb[0m: Currently logged in as: [33mdaisyabbott[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Downloading large artifact larger-perfect-dataset:v0, 893.10MB. 1852 files... 
[34m[1mwandb[0m:   1852 of 1852 files downloaded.  
Done. 0:0:0.3


In [2]:
from argparse import ArgumentParser

In [3]:
import matplotlib.pyplot as plt
from fastai.vision.all import *
from fastai.callback.progress import CSVLogger
import torch
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from PIL import Image
import numpy as np
from pathlib import Path

In [4]:
#validation percent and num of epochs
VALID_PCT = 0.05
NUM_EPOCHS = 3
NUM_REPLICATES = 2

In [5]:
# correct server num (using 4th for now, eventually switch to 3rd (2)and share w chau)
torch.cuda.set_device(3) 
print("Running on GPU: " + str(torch.cuda.current_device()))

Running on GPU: 3


In [6]:
# adjusted from initial raycasting file for simpler access
from pathlib import Path
current_dir = Path.cwd()
relative_path = "artifacts/larger-perfect-dataset:v0/data"
path1 = current_dir / relative_path

In [7]:
path1.ls()

(#1854) [Path('/home/dcad2021/artifacts/larger-perfect-dataset:v0/data/000000_0.png'),Path('/home/dcad2021/artifacts/larger-perfect-dataset:v0/data/000015_0.png'),Path('/home/dcad2021/artifacts/larger-perfect-dataset:v0/data/000060_0.png'),Path('/home/dcad2021/artifacts/larger-perfect-dataset:v0/data/000226_+0p70.png'),Path('/home/dcad2021/artifacts/larger-perfect-dataset:v0/data/000392_0.png'),Path('/home/dcad2021/artifacts/larger-perfect-dataset:v0/data/000075_0.png'),Path('/home/dcad2021/artifacts/larger-perfect-dataset:v0/data/000241_0.png'),Path('/home/dcad2021/artifacts/larger-perfect-dataset:v0/data/000135_0.png'),Path('/home/dcad2021/artifacts/larger-perfect-dataset:v0/data/000196_-0p88.png'),Path('/home/dcad2021/artifacts/larger-perfect-dataset:v0/data/000346_0.png')...]

In [8]:
files = get_image_files(path1)
use_pretraining = True
rgb_instead_of_gray = True 
rep = 1
model_name = "resnet18"

In [9]:
file_prefix = "classification-" + model_name
file_prefix += '-rgb' if rgb_instead_of_gray else '-gray'
file_prefix += '-pretrained' if use_pretraining else '-notpretrained'

In [10]:
compared_models = {
    "resnet18": resnet18,
    # "resnet34": resnet34
}

In [11]:
# I may need to double check the vars in this
model_filename = path1 / f"{file_prefix}-{rep}.pkl"
print("Model relative filename :", model_filename)
log_filename = path1 / f"{file_prefix}-trainlog-{rep}.csv"
print("Log relative filename   :", log_filename)
print("Log relative filename   :", log_filename)
fig_filename_prefix = path1 / file_prefix

Model relative filename : /home/dcad2021/artifacts/larger-perfect-dataset:v0/data/classification-resnet18-rgb-pretrained-1.pkl
Log relative filename   : /home/dcad2021/artifacts/larger-perfect-dataset:v0/data/classification-resnet18-rgb-pretrained-trainlog-1.csv
Log relative filename   : /home/dcad2021/artifacts/larger-perfect-dataset:v0/data/classification-resnet18-rgb-pretrained-trainlog-1.csv


In [12]:
def get_fig_filename(prefix: str, label: str, ext: str, rep: int) -> str:
    fig_filename = f"{prefix}-{label}-{rep}.{ext}"
    print(label, "filename :", fig_filename)
    return fig_filename

In [13]:
def filename_to_class(filename: str) -> str:
    angle = float(filename.split("_")[1].split(".")[0].replace("p", "."))
    if angle > 0:
        return "left"
    elif angle < 0:
        return "right"
    else:
        return "forward"

In [14]:
class ImageWithCmdDataset(Dataset):
    def __init__(self, filenames):
        """
        Creates objects for class labels, class indices, and filenames.
        
        :param filenames: (list) a list of filenames that make up the dataset
        """
        self.class_labels = ['left', 'forward', 'right']
        self.class_indices = {lbl:i for i, lbl in enumerate(self.class_labels)} # {'left': 0, 'forward': 1, 'right': 2}        
        self.all_filenames = filenames
        
    def __len__(self):
        """
        Gives length of dataset.
        
        :return: (int) the number of filenames in the dataset
        """
        return len(self.all_filenames)
    def __getitem__(self, index):
        """
        Gets the filename associated with the given index, opens the image at
        that index, then uses the image's filename to get information associated
        with the image such as its label and the label of the previous image.
        
        :param index: (int) number that represents the location of the desired data
        :return: (tuple) tuple of all the information associated with the desired data
        """
        # The filename of the image given a specific index
        img_filename = self.all_filenames[index]            
        
        # Opens image file and ensures dimension of channels included
        img = Image.open(img_filename).convert('RGB')
        # Resizes the image
        img = img.resize((224, 224))
        # Converts the image to tensor and 
        img = torch.Tensor(np.array(img)/255)
        # changes the order of the dimensions
        img = img.permute(2,0,1)
        
        # Getting the current image's label
        label_name = filename_to_class(img_filename)
        label = self.class_indices[label_name]
        
        # Getting the previous image's label
        # The default is 'forward'
        cmd_name = 'forward'
        
        # If the index is not 0, the cmd is determined by the previous img_filename
        if index != 0:
            prev_img_filename = self.all_filenames[index-1]
            cmd_name = filename_to_class(prev_img_filename)            
        cmd = self.class_indices[cmd_name]
        
        # Data and the label associated with that data
        return (img, cmd), label



In [16]:
dls = ImageDataLoaders.from_name_func(path1, files, filename_to_class, valid_pct = VALID_PCT)

In [17]:
#plt.savefig(get_fig_filename("batch", "pdf", rep))

In [18]:
class cmd_model(nn.Module):
    def __init__(self, arch: str, pretrained: bool):
        super(cmd_model, self).__init__()
        self.cnn = arch(pretrained=pretrained)
        
        self.fc1 = nn.Linear(self.cnn.fc.out_features + 1, 512)
        self.r1 = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(512, 3)
        
    def forward(self, data):
        print(data)
        img, cmd = data
        x1 = self.cnn(img)
        x2 = cmd.unsqueeze(1)
        
        x = torch.cat((x1, x2), dim=1)
        x = self.r1(self.fc1(x))
        x = self.fc2(x)
        return x

In [19]:
def prepare_dataloaders(dataset_name: str, prefix: str) -> DataLoaders:

    path = path1
    files = get_image_files(path1)
    
    # Get size of dataset and corresponding list of indices
    dataset_size = len(files)
    dataset_indices = list(range(dataset_size))
    
    # Shuffle the indices
    np.random.shuffle(dataset_indices)

    # Get the index for where we want to split the data
    val_split_index = int(np.floor(VALID_PCT * dataset_size))
    
    # Split the list of indices into training and validation indices
    train_idx, val_idx = dataset_indices[val_split_index:], dataset_indices[:val_split_index]
    
    # Get the list of filenames for the training and validation sets
    train_filenames = [files[i] for i in train_idx]
    val_filenames = [files[i] for i in val_idx]
    
    # Create training and validation datasets
    train_data = ImageWithCmdDataset(train_filenames)
#     train_data.__get_item__(10)
    val_data = ImageWithCmdDataset(val_filenames)
    
    # Get DataLoader
    dls = DataLoaders.from_dsets(train_data, val_data)
    dls = dls.cuda()

    #dls.show_batch()  # type: ignore
    plt.savefig(get_fig_filename(prefix, "batch", "pdf", 0))

    return dls  # type: ignore

In [20]:
def train_model(
    dls: DataLoaders,
    model_arch: str,
    pretrained: bool,
    logname: Path,
    modelname: Path,
    prefix: str,
    rep: int,
):
    arch = compared_models[model_arch]
    net = cmd_model(arch, pretrained=pretrained)
    
    learn = Learner(
        dls,
        net,
        loss_func=CrossEntropyLossFlat(),
        metrics=accuracy,
        cbs=CSVLogger(fname=logname),
    )

    if pretrained:
        learn.fine_tune(NUM_EPOCHS)
    else:
        learn.fit_one_cycle(NUM_EPOCHS)

    # Save trained model
    torch.save(net.state_dict(), modelname)


In [21]:
learn = cnn_learner(dls, compared_models[model_name], metrics=accuracy, pretrained=use_pretraining, cbs=CSVLogger(fname=log_filename))

  warn("`cnn_learner` has been renamed to `vision_learner` -- please update your code")


In [22]:
learn.path

Path('/home/dcad2021/artifacts/larger-perfect-dataset:v0/data')

In [23]:
if use_pretraining:
    learn.fine_tune(NUM_EPOCHS)
else:
    learn.fit_one_cycle(NUM_EPOCHS)

epoch,train_loss,valid_loss,accuracy,time
0,1.393035,0.564179,0.782609,00:34


epoch,train_loss,valid_loss,accuracy,time
0,0.58451,0.50949,0.858696,00:35
1,0.346169,0.186851,0.934783,00:35
2,0.212091,0.136459,0.934783,00:35


In [24]:
def main():

    arg_parser = ArgumentParser("Train cmd classification networks.")
    arg_parser.add_argument(
        "model_arch", help="Model architecture (see code for options)"
    )
    arg_parser.add_argument(
        "dataset_name", help="Name of dataset to use (corrected-wander-full)"
    )
    arg_parser.add_argument(
        "--pretrained", action="store_true", help="Use pretrained model"
    )

    args = arg_parser.parse_args()
    dls = prepare_dataloaders(dataset_path, fig_filename_prefix)
    
    # Train NUM_REPLICATES separate instances of this model and dataset
    for rep in range(NUM_REPLICATES):
         train_model(
            dls,
            args.model_arch,
            args.pretrained,
            log_filename,
            model_filename,
            fig_filename_prefix,
            rep,
        )

        