Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DRAFT - training overhaul - wip #21

Merged
merged 13 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ models/gaze_model_pytorch_vgg16_prl_mpii_allsubjects1.model
*.pt
*.dat
*.pth
output_images/*.*
46 changes: 5 additions & 41 deletions EmoDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,39 +25,13 @@ def __init__(self, use_gpu:False, sample_rate: int, n_sample_frames: int, width:
self.video_dir = video_dir
self.transform = transform
self.stage = stage
# self.face_alignment = face_alignment.FaceAlignment(face_alignment.LandmarksType.THREE_D, device='cpu')

# self.feature_extractor = Wav2VecFeatureExtractor(model_name='facebook/wav2vec2-base-960h', device='cuda')
# self.face_mask_generator = FaceHelper()
# self.pixel_transform = transforms.Compose(
# [
# transforms.RandomResizedCrop(
# (height, width),
# scale=self.img_scale,
# ratio=self.img_ratio,
# interpolation=transforms.InterpolationMode.BILINEAR,
# ),
# transforms.ToTensor(),
# transforms.Normalize([0.5], [0.5]),
# ]
# )
# Reduce 512 images -> 256
self.pixel_transform = transforms.Compose(
[
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
self.cond_transform = transforms.Compose(
[
transforms.RandomResizedCrop(
(height, width),
scale=self.img_scale,
ratio=self.img_ratio,
interpolation=transforms.InterpolationMode.BILINEAR,
),
# transforms.Resize((256, 256)), - just go HQ
transforms.ToTensor(),
# transforms.Normalize([0.5], [0.5]), - this makes picture go red
]
)

Expand All @@ -82,9 +56,7 @@ def __init__(self, use_gpu:False, sample_rate: int, n_sample_frames: int, width:
for frame_idx in range(video_length):
# Read frame and convert to PIL Image
frame = Image.fromarray(video_drv_reader[frame_idx].numpy())


# Transform the frame
# Transform the frame
state = torch.get_rng_state()
pixel_values_frame = self.augmentation(frame, self.pixel_transform, state)
self.driving_vid_pil_image_list.append(pixel_values_frame)
Expand All @@ -104,13 +76,13 @@ def augmentation(self, images, transform, state=None):
return ret_tensor

def __getitem__(self, index: int) -> Dict[str, Any]:
# print("__getitem__")
print("__getitem__")
video_id = self.video_ids[index]
mp4_path = os.path.join(self.video_dir, f"{video_id}.mp4")


video_reader = VideoReader(mp4_path, ctx=self.ctx)
video_length = 2 # frames len(video_reader)
video_length = len(video_reader)


vid_pil_image_list = []
Expand All @@ -121,14 +93,6 @@ def __getitem__(self, index: int) -> Dict[str, Any]:
# Read frame and convert to PIL Image
frame = Image.fromarray(video_reader[frame_idx].numpy())


# Detect keypoints using face_alignment - NOT USED BY Megaportraits
# keypoints = self.face_alignment.get_landmarks(video_reader[frame_idx].numpy())
# if keypoints is not None:
# keypoints_list.append(keypoints[0])
# else:
# keypoints_list.append(None)

# Transform the frame
state = torch.get_rng_state()
pixel_values_frame = self.augmentation(frame, self.pixel_transform, state)
Expand Down
9 changes: 8 additions & 1 deletion configs/training/stage1-base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,11 @@ training:
video_dir: '/media/oem/12TB/Downloads/CelebV-HQ/celebvhq/35666'
sample_rate: 25
n_sample_frames: 100
json_file: './data/overfit.json'
json_file: './data/overfit.json'


w_per: 20 # perceptual loss
w_adv: 1 # adversarial loss
w_fm: 40 # feature matching loss
w_cos: 2 # cosine-distance loss (L_cos) — NOTE(review): was labeled "cycle consistency"; confirm against the training code

92 changes: 65 additions & 27 deletions draw_warps.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,69 @@
import matplotlib.pyplot as plt
# from mpl_toolkits.mplot3d import Axes3D

# Create tensors for x, y, and z coordinates
x = torch.linspace(0, 10, 50)
y = torch.linspace(0, 10, 50)
X, Y = torch.meshgrid(x, y)
Z1 = torch.sin(X) + torch.randn(X.shape) * 0.2
Z2 = torch.sin(X + 1.5) + torch.randn(X.shape) * 0.2
Z3 = Z1 + Z2

# Create a figure and 3D axis
fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111, projection='3d')

# Plot the dots with quiver for direction/flow
q1 = ax.quiver(X, Y, Z1, Z1, Z1, Z1, length=0.1, normalize=True, cmap='viridis', label='x_e,k')
q2 = ax.quiver(X, Y, Z2, Z2, Z2, Z2, length=0.1, normalize=True, cmap='plasma', label='R_d+c,k')
q3 = ax.quiver(X, Y, Z3, Z3, Z3, Z3, length=0.1, normalize=True, cmap='inferno', label='R_d+c,k + t_d')

# Set labels and title
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
ax.set_title('PyTorch Tensor Plot (3D)')

# Add a legend
ax.legend()

# Display the plot
# # Create tensors for x, y, and z coordinates
# x = torch.linspace(0, 10, 50)
# y = torch.linspace(0, 10, 50)
# X, Y = torch.meshgrid(x, y)
# Z1 = torch.sin(X) + torch.randn(X.shape) * 0.2
# Z2 = torch.sin(X + 1.5) + torch.randn(X.shape) * 0.2
# Z3 = Z1 + Z2

# # Create a figure and 3D axis
# fig = plt.figure(figsize=(8, 6))
# ax = fig.add_subplot(111, projection='3d')

# # Plot the dots with quiver for direction/flow
# q1 = ax.quiver(X, Y, Z1, Z1, Z1, Z1, length=0.1, normalize=True, cmap='viridis', label='x_e,k')
# q2 = ax.quiver(X, Y, Z2, Z2, Z2, Z2, length=0.1, normalize=True, cmap='plasma', label='R_d+c,k')
# q3 = ax.quiver(X, Y, Z3, Z3, Z3, Z3, length=0.1, normalize=True, cmap='inferno', label='R_d+c,k + t_d')

# # Set labels and title
# ax.set_xlabel('x')
# ax.set_ylabel('y')
# ax.set_zlabel('z')
# ax.set_title('PyTorch Tensor Plot (3D)')

# # Add a legend
# ax.legend()

# # Display the plot
# plt.show()


"""Visualise the displacement field of a rotated affine sampling grid.

Builds two ``F.affine_grid`` volumes over a 1x1x2x3x4 tensor -- one from
the identity transform, one from a rotation about the x axis -- and draws
their per-voxel difference as a 3-D quiver plot.
"""
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d

import torch
import numpy as np
import torch.nn.functional as F


# Sampling-grid size: batch=1, channel=1, depth=2, height=3, width=4.
VOLUME = [1, 1, 2, 3, 4]

# Identity affine -> the undisplaced reference grid.
identity = torch.tensor(
    [[1, 0, 0, 0],
     [0, 1, 0, 0],
     [0, 0, 1, 0]],
    dtype=torch.float32,
)
base = F.affine_grid(identity.unsqueeze(0), VOLUME, align_corners=True)

# Rotation about the x axis (y -> z, z -> -y).
rotation = torch.tensor(
    [[1, 0, 0, 0],
     [0, 0, 1, 0],
     [0, -1, 0, 0]],
    dtype=torch.float32,
)
rotated = F.affine_grid(rotation.unsqueeze(0), VOLUME, align_corners=True)

# Per-voxel displacement with the batch dim dropped: shape (D, H, W, 3).
grid = (rotated - base)[0]
D, H, W, _ = grid.shape

fig = plt.figure()
ax = fig.add_subplot(projection="3d")

# Integer voxel coordinates in (depth, row, col) order.
k, j, i = np.meshgrid(
    np.arange(D),
    np.arange(H),
    np.arange(W),
    indexing="ij",
)

# Split the displacement into its x/y/z components.
u, v, w = (grid[..., axis].numpy() for axis in range(3))

ax.quiver(k, j, i, w, v, u, length=0.3)
plt.show()
Loading