Hello, I was following the tutorial, but in Anaconda. Everything was fine until I reached 4:51:36 and realized that I don't know how to get the plot to show. I am running this from a .py file,
like this:
import torch
from torch import nn
import matplotlib.pyplot as plt

weight = 0.7
bias = 0.3
start = 0
end = 1
step = 0.02
X = torch.arange(start, end, step).unsqueeze(dim=1)
y = weight * X + bias

train_split = int(0.8 * len(X))
X_train, y_train = X[:train_split], y[:train_split]
X_test, y_test = X[train_split:], y[train_split:]

def plot_predictions(train_data=X_train,
                     train_labels=y_train,
                     test_data=X_test,
                     test_labels=y_test,
                     predictions=None):
    """
    Plots training data, test data and compares predictions.
    """
    plt.figure(figsize=(10, 7))
    # Plot training data in blue
    plt.scatter(train_data, train_labels, c="b", s=4, label="Training data")
    # Plot test data in green
    plt.scatter(test_data, test_labels, c="g", s=4, label="Testing data")
    if predictions is not None:
        # Plot the predictions in red (predictions were made on the test data)
        plt.scatter(test_data, predictions, c="r", s=4, label="Predictions")
    # Show the legend
    plt.legend(prop={"size": 14})
Where am I supposed to call plot_predictions()?
At the end of the code?
At the prompt?
Or do I need something else?
Thanks for your time, and sorry for my English.
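For reference, here is a minimal sketch of one way this could work from a plain .py file. A script does not render figures automatically the way a notebook does, so the idea is to call plot_predictions() at the end of the file, after the function and the train/test split are defined, and then call plt.show() to open the window. This assumes matplotlib can open a GUI window in your Anaconda environment. Returning the figure from the function is my own addition for convenience, not something from the tutorial.

```python
import matplotlib.pyplot as plt
import torch

# Same synthetic data as in the post
weight, bias = 0.7, 0.3
X = torch.arange(0, 1, 0.02).unsqueeze(dim=1)
y = weight * X + bias

# 80/20 train/test split
train_split = int(0.8 * len(X))
X_train, y_train = X[:train_split], y[:train_split]
X_test, y_test = X[train_split:], y[train_split:]


def plot_predictions(train_data=X_train,
                     train_labels=y_train,
                     test_data=X_test,
                     test_labels=y_test,
                     predictions=None):
    """Plots training data, test data and compares predictions."""
    fig = plt.figure(figsize=(10, 7))
    plt.scatter(train_data, train_labels, c="b", s=4, label="Training data")
    plt.scatter(test_data, test_labels, c="g", s=4, label="Testing data")
    if predictions is not None:
        plt.scatter(test_data, predictions, c="r", s=4, label="Predictions")
    plt.legend(prop={"size": 14})
    return fig  # assumption: returned only for convenience, not in the tutorial


if __name__ == "__main__":
    # Call the function at the end of the script, once everything is defined...
    plot_predictions()
    # ...then ask matplotlib to display the figure it has drawn.
    plt.show()
```

Run it with `python your_script.py` (the file name is a placeholder). If no window appears, for example on a machine without a display, replacing `plt.show()` with `plt.savefig("plot.png")` writes the figure to a file instead.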