In [None]:
# Install the face_recognition package into the notebook runtime.
# NOTE(review): no cell visible in this notebook imports face_recognition — confirm it is still needed.
!pip3 install face_recognition

In [None]:
import glob
import sys

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from google.colab import drive # type: ignore
from sklearn.metrics import confusion_matrix
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torchvision import models
from torchvision import transforms
from torchvision import transforms as tv_transforms
# Mount Google Drive at /content/drive; the video folders and metadata CSV read below live under this path.
drive.mount('/content/drive')

In [None]:
def frame_extract(path):
  """Yield the frames of the video at `path` one at a time, in order.

  Frames are whatever `cv2.VideoCapture.read()` returns. Iteration stops at
  the first unsuccessful read, so an unreadable file yields nothing.
  """
  vidObj = cv2.VideoCapture(path)
  try:
    while True:
      success, image = vidObj.read()
      if not success:
        break
      yield image
  finally:
    # Runs on exhaustion and when the consumer stops early (generator closed
    # or garbage collected), so the capture handle is always released.
    vidObj.release()

In [None]:
def validate_video(vid_path, train_transforms):
	"""Decode up to 60 frames of `vid_path`, transform each, and stack them.

	Each frame produced by `frame_extract` is passed through
	`train_transforms`; the results are combined with `torch.stack`.
	If no frame can be read, `torch.stack` raises on the empty list.
	"""
	max_frames = 60
	tensors = []
	for raw_frame in frame_extract(vid_path):
		tensors.append(train_transforms(raw_frame))
		# Stop reading as soon as enough frames are collected.
		if len(tensors) == max_frames:
			break

	return torch.stack(tensors)[:max_frames]

In [None]:
im_size = 112
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# The pipeline is built through the `tv_transforms` module alias because this
# cell rebinds the name `transforms` to the composed pipeline. Building it from
# `transforms` itself made the cell fail on a second run (the Compose object has
# no `.Compose`). The name `transforms` is kept because later cells pass it on.
transforms = tv_transforms.Compose([
  tv_transforms.ToPILImage(),
  tv_transforms.Resize((im_size, im_size)),
  tv_transforms.ToTensor(),
  tv_transforms.Normalize(mean, std)
])

# Folders (under My Drive) that hold the face-only video clips.
video_dirs = [
  'Celeb_fake_face_only',
  'Celeb_real_face_only',
  'DFDC_FAKE_Face_only_data',
  'DFDC_REAL_Face_only_data',
  'FF_Face_only_data',
]

video_fil = []
for video_dir in video_dirs:
  # glob returns files in arbitrary order; sorting keeps the list reproducible between runs.
  video_fil += sorted(glob.glob(f'/content/drive/My Drive/{video_dir}/*.mp4'))

print("Total no of videos :" , len(video_fil))

In [None]:
# Try to decode every video once and report the ones that cannot be read.
count = 0
corrupted_videos = []

for count, i in enumerate(video_fil, start = 1):
  try:
    validate_video(i, transforms)
  except Exception as err:  # not a bare except: Ctrl-C must still interrupt the scan
    corrupted_videos.append(i)
    print("Number of video processed: " , count ," Remaining : " , (len(video_fil) - count))
    print("Corrupted video is : " , i, "-", repr(err))

# Failures are counted separately: `count` advances for every video, so
# `len(video_fil) - count` is always 0 once the loop has finished.
print(f"Total Corrupted Videos: {len(corrupted_videos)}")

In [None]:
# Keep only videos with at least 60 frames and record their frame counts.
frame_count = []
long_enough_videos = []

for video_file in video_fil:
  cap = cv2.VideoCapture(video_file)
  n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
  cap.release()
  if n_frames < 60:
    continue
  long_enough_videos.append(video_file)
  frame_count.append(n_frames)

# Rebind instead of calling video_fil.remove() inside the loop: removing from a
# list while iterating over it skips the element that follows each removal.
video_fil = long_enough_videos

print("frames are " , frame_count)
print("Total no of video: " , len(frame_count))
print('Average frame per video:', np.mean(frame_count))

In [None]:
class video_dataset(Dataset):
	"""Dataset that yields `(frames, label)` for each video path.

	Args:
		video_names: sequence of video file paths.
		labels: DataFrame with a "file" column holding bare file names; the
			label is read from the second column (by position).
		sequence_length: maximum number of frames returned per video.
		transform: callable applied to every decoded frame. When None, frames
			are converted with `torch.as_tensor` so they can still be stacked.
	"""

	def __init__(self, video_names, labels, sequence_length = 60, transform = None):
		self.video_names = video_names
		self.labels = labels
		self.transform = transform
		self.count = sequence_length

	def __len__(self):
		return len(self.video_names)

	def frame_extract(self, path):
		"""Yield the frames of the video at `path`, releasing the capture when done."""
		vidObj = cv2.VideoCapture(path)
		try:
			while True:
				success, image = vidObj.read()
				if not success:
					break
				yield image
		finally:
			vidObj.release()

	def __getitem__(self, idx):
		"""Return `(frames, label)` for the video at position `idx`.

		'FAKE' maps to 0 and 'REAL' to 1; any other label value is returned
		unchanged.

		Raises:
			KeyError: the file name has no row in the labels table.
			RuntimeError: no frame could be decoded from the video.
		"""
		video_path = self.video_names[idx]
		temp_video = video_path.split('/')[-1]

		# Look the label up on self.labels (the original read an unassigned local).
		# Taking the first matching row by position works for any DataFrame index.
		matches = self.labels.loc[self.labels["file"] == temp_video]
		if len(matches) == 0:
			# KeyError rather than IndexError: an IndexError raised from
			# __getitem__ can be mistaken for "end of sequence" by iteration.
			raise KeyError("no label found for video file %r" % temp_video)
		label = matches.iloc[0, 1]
		if(label == 'FAKE'):
			label = 0
		if(label == 'REAL'):
			label = 1

		frames = []
		for frame in self.frame_extract(video_path):
			if self.transform is None:
				frames.append(torch.as_tensor(frame))
			else:
				frames.append(self.transform(frame))
			if(len(frames) == self.count):
				break
		if not frames:
			raise RuntimeError("no frames could be read from %r" % video_path)

		return torch.stack(frames), label

In [None]:
def number_of_real_and_fake_videos(data_list, metadata_path = '/content/drive/My Drive/Gobal_metadata.csv'):
  """Count how many videos in `data_list` are labelled REAL and FAKE.

  Args:
    data_list: iterable of video paths; only the file name after the last
      '/' is used for the lookup.
    metadata_path: header-less CSV whose two columns are read as
      ("file", "label").

  Returns:
    Tuple `(real, fake)`. Labels other than 'REAL'/'FAKE' are not counted.

  Raises:
    KeyError: a video's file name is not listed in the metadata.
  """
  header_list = ["file", "label"]
  lab = pd.read_csv(metadata_path, names = header_list)
  # One dict lookup per video instead of scanning the whole table each time.
  # For duplicated file names the first row wins.
  label_by_file = lab.drop_duplicates("file").set_index("file")["label"].to_dict()

  fake = 0
  real = 0
  for i in data_list:
    temp_video = i.split('/')[-1]
    if temp_video not in label_by_file:
      raise KeyError(f"{temp_video} is not listed in {metadata_path}")
    label = label_by_file[temp_video]
    if(label == 'FAKE'):
      fake += 1
    elif(label == 'REAL'):
      real += 1
  return real, fake

In [None]:
header_list = ["file", "label"]
labels = pd.read_csv('/content/drive/My Drive/Gobal_metadata.csv', names = header_list)

# video_fil is ordered by source folder, so slicing it directly would put almost
# only the last folder's videos into the test split. Shuffle with a fixed seed
# first (into a new list, leaving video_fil untouched) for a mixed, repeatable split.
split_order = np.random.default_rng(42).permutation(len(video_fil))
shuffled_videos = [video_fil[idx] for idx in split_order]

split_at = int(0.8 * len(shuffled_videos))
train_videos = shuffled_videos[:split_at]
test_videos = shuffled_videos[split_at:]

print("train : " , len(train_videos))
print("test : " , len(test_videos))

# Count once per split; every call re-reads the metadata CSV.
train_real, train_fake = number_of_real_and_fake_videos(train_videos)
test_real, test_fake = number_of_real_and_fake_videos(test_videos)

print("TRAIN: ", "Real:",train_real," Fake:",train_fake)
print("TEST: ", "Real:",test_real," Fake:",test_fake)

In [None]:
# Both splits share the same dataset and loader settings.
dataset_kwargs = dict(sequence_length = 60, transform = transforms)
train_data = video_dataset(train_videos, labels, **dataset_kwargs)
test_data = video_dataset(test_videos, labels, **dataset_kwargs)

loader_kwargs = dict(batch_size = 4, shuffle = True, num_workers = 4)
train_loader = DataLoader(train_data, **loader_kwargs)
test_loader = DataLoader(test_data, **loader_kwargs)

In [None]:
class Model(nn.Module):
	"""ResNeXt-50 frame encoder followed by an LSTM over the frame sequence.

	Args:
		num_classes: size of the output logit vector.
		latent_dim: per-frame feature size fed to the LSTM; must equal the
			channel count produced by the truncated backbone.
		lstm_layers: number of stacked LSTM layers.
		hidden_dim: LSTM hidden size.
		bidirectional: run the LSTM in both directions.
	"""

	def __init__(self, num_classes, latent_dim = 2048, lstm_layers = 1, hidden_dim = 2048, bidirectional = False):
		super(Model, self).__init__()
		model = models.resnext50_32x4d(pretrained = True)
		# Drop the backbone's last two children, keeping the spatial feature map.
		self.model = nn.Sequential(*list(model.children())[:-2])
		self.latent_dim = latent_dim
		# `bidirectional` must be passed by keyword: the 4th positional argument
		# of nn.LSTM is `bias`, so passing it positionally silently disabled the
		# bias. batch_first=True because forward() feeds (batch, seq, feature).
		self.lstm = nn.LSTM(latent_dim, hidden_dim, lstm_layers, bidirectional = bidirectional, batch_first = True)
		self.relu = nn.LeakyReLU()
		self.dp = nn.Dropout(0.4)
		# A bidirectional LSTM concatenates both directions on the feature axis.
		lstm_out_dim = hidden_dim * (2 if bidirectional else 1)
		self.linear1 = nn.Linear(lstm_out_dim, num_classes)
		self.avgpool = nn.AdaptiveAvgPool2d(1)

	def forward(self, x):
		"""Classify a batch of frame sequences.

		Args:
			x: tensor unpacked as (batch, seq_length, channels, height, width).

		Returns:
			`(fmap, logits)`: the backbone feature map for all
			batch * seq_length frames, and per-sequence logits.
		"""
		batch_size, seq_length, c, h, w = x.shape
		# Fold time into the batch axis so the CNN sees individual frames.
		x = x.view(batch_size * seq_length, c, h, w)
		fmap = self.model(x)
		x = self.avgpool(fmap)
		x = x.view(batch_size, seq_length, self.latent_dim)
		x_lstm, _ = self.lstm(x, None)

		# Average the LSTM outputs over the time axis before the classifier.
		return fmap, self.dp(self.linear1(torch.mean(x_lstm, dim = 1)))

In [None]:
# Two output classes. Move the model to the GPU only when one exists so the
# cell does not crash on a CPU-only runtime.
model = Model(2)
if torch.cuda.is_available():
  model = model.cuda()

In [None]:
class AverageMeter(object):
	"""Tracks the latest value and the running (weighted) average of a metric."""

	def __init__(self):
		self.reset()

	def reset(self):
		"""Clear all accumulated statistics."""
		self.val = 0
		self.avg = 0
		self.sum = 0
		self.count = 0

	def update(self, val, n = 1):
		"""Record `val`, observed over `n` samples.

		`n` defaults to 1, so single-argument calls behave as before. The
		training and test loops call `update(value, batch_size)`, which the
		previous one-argument signature rejected with a TypeError.
		"""
		self.val = val
		self.sum += val * n
		self.count += n
		# Guard against a zero total weight (e.g. update(x, 0) on a fresh meter).
		self.avg = self.sum / self.count if self.count else 0

In [None]:
def calculate_accuracy(outputs, targets):
	"""Return the top-1 accuracy of `outputs` against `targets`, in percent.

	The predicted class is the index of the largest value along dim 1 of
	`outputs`; the result is 100 * (matching predictions) / targets.size(0).
	Assumes `outputs` is (batch, num_classes) and `targets` holds one class
	index per sample — TODO confirm against callers.
	"""
	n_samples = targets.size(0)
	top1 = outputs.topk(1, dim = 1, largest = True)[1]
	# Transpose to a single row so it lines up with the flattened targets.
	hits = top1.t().eq(targets.view(1, -1))

	return 100 * hits.float().sum().item() / n_samples

In [None]:
def train_epoch(epoch, num_epochs, data_loader, model, criterion, optimizer, checkpoint_path = '/content/checkpoint.pt'):
	"""Train `model` for one pass over `data_loader` and save a checkpoint.

	Args:
		epoch, num_epochs: used only in the progress line.
		data_loader: yields `(inputs, targets)` batches.
		model: called as `model(inputs)`; its second return value is scored.
		criterion: loss called as `criterion(outputs, targets)`.
		optimizer: optimizer stepped once per batch.
		checkpoint_path: where the model's state_dict is written after the epoch.

	Returns:
		`(average loss, average accuracy)` as reported by the AverageMeter objects.
	"""
	model.train()
	losses = AverageMeter()
	accuracies = AverageMeter()

	# Follow the device the model lives on instead of hard-coding CUDA tensor
	# types, which raised on CPU-only runtimes.
	first_param = next(model.parameters(), None)
	device = first_param.device if first_param is not None else torch.device('cpu')

	for i, (inputs, targets) in enumerate(data_loader):
		inputs = inputs.to(device)
		targets = targets.to(device).long()
		_,outputs = model(inputs)
		loss  = criterion(outputs, targets)
		acc = calculate_accuracy(outputs, targets)
		losses.update(loss.item(), inputs.size(0))
		accuracies.update(acc, inputs.size(0))
		optimizer.zero_grad()
		loss.backward()
		optimizer.step()
		sys.stdout.write(
			"\r[Epoch %d/%d] [Batch %d / %d] [Loss: %f, Acc: %.2f%%]"
			% (
					epoch,
					num_epochs,
					i,
					len(data_loader),
					losses.avg,
					accuracies.avg))

	torch.save(model.state_dict(), checkpoint_path)

	return losses.avg,accuracies.avg

In [None]:
def test(model, data_loader ,criterion):
	"""Evaluate `model` on `data_loader` without tracking gradients.

	Args:
		model: called as `model(inputs)`; its second return value is scored.
		data_loader: yields `(inputs, targets)` batches.
		criterion: loss called as `criterion(outputs, targets)`.

	Returns:
		`(true, pred, average loss, average accuracy)`, where `true` and `pred`
		are flat Python lists of target and predicted class indices.
	"""
	print('Testing')
	model.eval()
	losses = AverageMeter()
	accuracies = AverageMeter()
	pred = []
	true = []

	# Follow the device the model lives on instead of hard-coding CUDA tensor
	# types, which raised on CPU-only runtimes.
	first_param = next(model.parameters(), None)
	device = first_param.device if first_param is not None else torch.device('cpu')

	with torch.no_grad():
		for i, (inputs, targets) in enumerate(data_loader):
				inputs = inputs.to(device)
				targets = targets.to(device).long()
				_,outputs = model(inputs)
				# torch.mean keeps this a scalar even for an unreduced criterion.
				loss = torch.mean(criterion(outputs, targets))
				acc = calculate_accuracy(outputs, targets)
				_,p = torch.max(outputs, 1)
				true += targets.detach().cpu().reshape(-1).tolist()
				pred += p.detach().cpu().reshape(-1).tolist()
				losses.update(loss.item(), inputs.size(0))
				accuracies.update(acc, inputs.size(0))
				sys.stdout.write(
					"\r[Batch %d / %d]  [Loss: %f, Acc: %.2f%%]"
					% (
							i,
							len(data_loader),
							losses.avg,
							accuracies.avg
							)
					)
		print('\nAccuracy {}'.format(accuracies.avg))

	return true, pred, losses.avg,accuracies.avg

In [None]:
def print_confusion_matrix(y_true, y_pred):
	"""Print the confusion-matrix counts, draw them as a heatmap, and print accuracy.

	Class 0 is drawn as 'Fake' and class 1 as 'Real'; class 0 is the positive
	class in the printed counts. Rows of the matrix are actual labels and
	columns are predicted labels.
	"""
	# labels=[0, 1] keeps the matrix 2x2 even when only one class is present.
	cm = confusion_matrix(y_true, y_pred, labels = [0, 1])
	print('True positive = ', cm[0][0])
	# cm[i][j] counts samples of actual class i predicted as class j, so with
	# class 0 as positive a false positive is cm[1][0] and a false negative is
	# cm[0][1]. These two values were previously printed under each other's label.
	print('False positive = ', cm[1][0])
	print('False negative = ', cm[0][1])
	print('True negative = ', cm[1][1])
	print('\n')
	df_cm = pd.DataFrame(cm, range(2), range(2))
	sns.set(font_scale=1.4)
	sns.heatmap(df_cm, annot=True, annot_kws={"size": 16})
	plt.ylabel('Actual label', size = 20)
	plt.xlabel('Predicted label', size = 20)
	plt.xticks(np.arange(2), ['Fake', 'Real'], size = 16)
	plt.yticks(np.arange(2), ['Fake', 'Real'], size = 16)
	plt.ylim([2, 0])
	plt.show()
	calculated_acc = (cm[0][0] + cm[1][1])/(cm[0][0] + cm[0][1] + cm[1][0] + cm[1][1])
	print("Calculated Accuracy", calculated_acc * 100)

In [None]:
def plot_loss(train_loss_avg, test_loss_avg, num_epochs):
  """Plot training and validation loss against the epoch number.

  Args:
    train_loss_avg: per-epoch training losses.
    test_loss_avg: per-epoch validation losses.
    num_epochs: number of epochs; the x axis runs from 1 to this value.
  """
  # The stray debug print of num_epochs was removed.
  epochs = range(1, num_epochs + 1)
  plt.plot(epochs, train_loss_avg, 'g', label = 'Training loss')
  plt.plot(epochs, test_loss_avg, 'b', label = 'validation loss')
  plt.title('Training and Validation loss')
  plt.xlabel('Epochs')
  plt.ylabel('Loss')
  plt.legend()
  plt.show()

In [None]:
def plot_accuracy(train_accuracy, test_accuracy, num_epochs):
  """Plot training and validation accuracy against the epoch number.

  The x axis runs from 1 to `num_epochs`; training is drawn in green and
  validation in blue.
  """
  epoch_axis = range(1, num_epochs + 1)
  curves = (
    (train_accuracy, 'g', 'Training accuracy'),
    (test_accuracy, 'b', 'validation accuracy'),
  )
  for values, line_format, legend_name in curves:
    plt.plot(epoch_axis, values, line_format, label = legend_name)
  plt.title('Training and Validation accuracy')
  plt.xlabel('Epochs')
  plt.ylabel('Accuracy')
  plt.legend()
  plt.show()

In [None]:
# Optimisation settings.
lr = 1e-5
num_epochs = 20

optimizer = torch.optim.Adam(model.parameters(), weight_decay = 1e-5, lr = lr)
criterion = nn.CrossEntropyLoss().cuda()

# Per-epoch history, one entry appended per epoch.
train_loss_avg, train_accuracy = [], []
test_loss_avg, test_accuracy = [], []

for epoch in range(1, num_epochs + 1):
	# Train for one epoch, then evaluate on the test loader.
	l, acc = train_epoch(epoch, num_epochs, train_loader, model, criterion, optimizer)
	train_loss_avg += [l]
	train_accuracy += [acc]
	true, pred, tl, t_acc = test(model, test_loader, criterion)
	test_loss_avg += [tl]
	test_accuracy += [t_acc]

In [None]:
# Plot the per-epoch loss and accuracy curves; the epoch count is taken from the history length.
plot_loss(train_loss_avg, test_loss_avg, len(train_loss_avg))
plot_accuracy(train_accuracy, test_accuracy, len(train_accuracy))
# `true` and `pred` hold the values from the last call to test() in the training loop.
print_confusion_matrix(true, pred)