Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 59 additions & 22 deletions vertexai/ImagenScreen/ImagenScreen.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

import SwiftUI
import GenerativeAIUIComponents

struct ImagenScreen: View {
@StateObject var viewModel = ImagenViewModel()
Expand All @@ -25,29 +26,38 @@ struct ImagenScreen: View {
var focusedField: FocusedField?

var body: some View {
VStack {
TextField("Enter a prompt to generate an image", text: $viewModel.userInput)
.focused($focusedField, equals: .message)
.textFieldStyle(.roundedBorder)
.onSubmit {
onGenerateTapped()
ZStack {
VStack {
InputField("Enter a prompt to generate an image", text: $viewModel.userInput) {
Image(
systemName: viewModel.inProgress ? "stop.circle.fill" : "paperplane.circle.fill"
)
.font(.title)
}
.padding()
.focused($focusedField, equals: .message)
.onSubmit { sendOrStop() }

Button("Generate") {
onGenerateTapped()
ScrollView {
let spacing: CGFloat = 10
LazyVGrid(columns: [
GridItem(.fixed(UIScreen.main.bounds.width / 2 - spacing), spacing: spacing),
GridItem(.fixed(UIScreen.main.bounds.width / 2 - spacing), spacing: spacing),
], spacing: spacing) {
ForEach(viewModel.images, id: \.self) { image in
Image(uiImage: image)
.resizable()
.aspectRatio(contentMode: .fill)
.frame(width: UIScreen.main.bounds.width / 2 - spacing,
height: UIScreen.main.bounds.width / 2 - spacing)
.cornerRadius(12)
.clipped()
}
}
.padding(.horizontal, spacing)
}
}
.padding()
if viewModel.inProgress {
Text("Waiting for model response ...")
}
ForEach(viewModel.images, id: \.self) {
Image(uiImage: $0)
.resizable()
.scaledToFill()
.frame(minWidth: 0, maxWidth: .infinity, minHeight: 0, maxHeight: .infinity)
.aspectRatio(nil, contentMode: .fit)
.clipped()
ProgressOverlay()
}
}
.navigationTitle("Imagen sample")
Expand All @@ -56,11 +66,38 @@ struct ImagenScreen: View {
}
}

private func onGenerateTapped() {
focusedField = nil

private func sendMessage() {
Task {
await viewModel.generateImage(prompt: viewModel.userInput)
focusedField = .message
}
}

private func sendOrStop() {
if viewModel.inProgress {
viewModel.stop()
} else {
sendMessage()
}
}
}

struct ProgressOverlay: View {
var body: some View {
ZStack {
RoundedRectangle(cornerRadius: 16)
.fill(Material.ultraThinMaterial)
.frame(width: 120, height: 100)
.shadow(radius: 8)

VStack(spacing: 12) {
ProgressView()
.scaleEffect(1.5)

Text("Loading...")
.font(.subheadline)
.foregroundColor(.secondary)
}
}
}
}
Expand Down
44 changes: 28 additions & 16 deletions vertexai/ImagenScreen/ImagenViewModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class ImagenViewModel: ObservableObject {

private let model: ImagenModel

private var generateImagesTask: Task<Void, Never>?

// 1. Initialize the Vertex AI service
private let vertexAI = VertexAI.vertexAI()

Expand All @@ -57,27 +59,37 @@ class ImagenViewModel: ObservableObject {
}

func generateImage(prompt: String) async {
guard !inProgress else {
print("Already generating images...")
return
}
do {
stop()

generateImagesTask = Task {
inProgress = true
defer {
inProgress = false
}
inProgress = true
// 4. Call generateImages with the text prompt
let response = try await model.generateImages(prompt: prompt)

// 5. Print the reason images were filtered out, if any.
if let filteredReason = response.filteredReason {
print("Image(s) Blocked: \(filteredReason)")
}
do {
// 4. Call generateImages with the text prompt
let response = try await model.generateImages(prompt: prompt)

// 5. Print the reason images were filtered out, if any.
if let filteredReason = response.filteredReason {
print("Image(s) Blocked: \(filteredReason)")
}

// 6. Convert the image data to UIImage for display in the UI
images = response.images.compactMap { UIImage(data: $0.data) }
} catch {
logger.error("Error generating images: \(error)")
if !Task.isCancelled {
// 6. Convert the image data to UIImage for display in the UI
images = response.images.compactMap { UIImage(data: $0.data) }
}
} catch {
if !Task.isCancelled {
logger.error("Error generating images: \(error)")
}
}
}
}

func stop() {
generateImagesTask?.cancel()
generateImagesTask = nil
}
}