Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions 31_image_classification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1932,7 +1932,8 @@
"metadata": {},
"source": [
"After initialising the trainer instance, check whether a trained model already exists.\n",
"If so, load the weights using ```model_weights = torch.load(model_path, weights_only=True)```.\n",
"If so, load the weights using ```model_weights = torch.load(model_path, weights_only=True, map_location=torch.device('cpu'))```. \n",
"The ```map_location=torch.device('cpu')``` is only needed if you are running the code in a computer that does not have CUDA cores.\n",
"Then, load the weights into the model using (```model.load_state_dict(model_weights)```).\n",
"Finally, set the model to evaluation model (```model.eval()```).\n",
"This step is essential because certain layers, such as batch normalization and dropout, behave differently during training and evaluation.\n",
Expand Down Expand Up @@ -1974,7 +1975,7 @@
"Overfitting occurs when the model performs well on the training data but poorly on the validation data, usually indicated by a widening gap between the two curves.\n",
"Underfitting, on the other hand, is suggested when both the training and validation curves show poor performance and fail to improve. By monitoring these curves, we can adjust hyperparameters or modify the model architecture to address such issues. \n",
"\n",
"First, load the log file using ```pandas``` (```training_log = pd.read_csv(\"training_log.txt\")```).\n",
"First, load the log file using ```pandas``` (```training_log = pd.read_csv(\"training_log.txt\")``` or ```training_log = pd.read_csv(dataset_folder / \"training_log.txt\")``` in case you did not train the model by yourself).\n",
"Then, use the ```matplotlib``` library to plot the learning curves."
]
},
Expand All @@ -1988,7 +1989,7 @@
"import pandas as pd\n",
"from matplotlib import pyplot as plt\n",
"\n",
"# Load the training log file (In case you want to use the already trained model, replace this by model_path = dataset_folder / \"training_log.txt\")\n",
"# Load the training log file\n",
"training_log = None\n",
"\n",
"plt.figure()\n",
Expand Down